diff --git a/llm/testing/__init__.py b/llm/testing/__init__.py new file mode 100644 index 00000000..60e1fcaa --- /dev/null +++ b/llm/testing/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Testing utilities for pypto-lib (accuracy comparison, regression harnesses, ...).""" diff --git a/llm/testing/hf_compare/__init__.py b/llm/testing/hf_compare/__init__.py new file mode 100644 index 00000000..1436d557 --- /dev/null +++ b/llm/testing/hf_compare/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Hugging Face comparison framework. + +This package provides a pluggable harness for verifying that PyPTO kernels / +scopes / end-to-end models produce numerically equivalent outputs to a +reference implementation (typically a Hugging Face module on CPU). + +Top-level entry points: + + from llm.testing.hf_compare import ( + ComparisonCase, register_case, get_case, list_cases, + ) + +Cases live under ``llm/testing/hf_compare/cases/`` and are registered with +``@register_case("model.scope")``. The CLI in ``__main__.py`` discovers them +by name. +""" +from .base import ( + ChainStep, + ChainedComparisonCase, + CompareReport, + ComparisonCase, + InputSpec, + OutputSelector, + ReferenceModel, + SelectorResult, + TargetModel, + TensorSpec, + Tolerance, + WeightAdapter, + get_case, + list_cases, + register_case, +) + +__all__ = [ + "ChainStep", + "ChainedComparisonCase", + "CompareReport", + "ComparisonCase", + "InputSpec", + "OutputSelector", + "ReferenceModel", + "SelectorResult", + "TargetModel", + "TensorSpec", + "Tolerance", + "WeightAdapter", + "get_case", + "list_cases", + "register_case", +] diff --git a/llm/testing/hf_compare/__main__.py b/llm/testing/hf_compare/__main__.py new file mode 100644 index 00000000..d47e29e3 --- /dev/null +++ b/llm/testing/hf_compare/__main__.py @@ -0,0 +1,66 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""CLI for the HF comparison framework. + +Usage: + python -m llm.testing.hf_compare list + python -m llm.testing.hf_compare run qwen3_14b.decode \\ + --hf-model-path /path/to/Qwen3-14B --platform a2a3 [--cpu-only] +""" +from __future__ import annotations + +import argparse +import json +import sys + +from .base import get_case, list_cases + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(prog="hf_compare") + sub = parser.add_subparsers(dest="cmd", required=True) + + sub.add_parser("list", help="List registered cases.") + + run = sub.add_parser("run", help="Run one case.") + run.add_argument("case", help="Registered case name (see `list`).") + run.add_argument("--json", action="store_true", help="Print JSON report.") + run.add_argument( + "--kwarg", "-k", action="append", default=[], + help="Extra kwarg for the case factory, e.g. -k hf_model_path=/data/...", + ) + + args = parser.parse_args(argv) + # Autodiscovery is lazy: list_cases() / get_case() trigger it internally. + + if args.cmd == "list": + for name in list_cases(): + print(name) + return 0 + + if args.cmd == "run": + kwargs: dict[str, str] = {} + for kv in args.kwarg: + if "=" not in kv: + parser.error(f"--kwarg must be key=value, got {kv!r}") + k, v = kv.split("=", 1) + kwargs[k] = v + case = get_case(args.case, **kwargs) + report = case.run() + if args.json: + print(json.dumps(report.to_json(), indent=2)) + else: + print(report.summary()) + return 0 if report.passed else 1 + + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/llm/testing/hf_compare/base.py b/llm/testing/hf_compare/base.py new file mode 100644 index 00000000..65bdfe23 --- /dev/null +++ b/llm/testing/hf_compare/base.py @@ -0,0 +1,327 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Core abstractions for the HF comparison framework. + +A ComparisonCase bundles four pluggable pieces: + - ReferenceModel: the HF-side callable that produces golden outputs. + - TargetModel: the PyPTO-side callable under test (kernel / scope / e2e). + - WeightAdapter: maps HF state_dict entries onto tensors the target expects. + - InputSpec: declarative description of the inputs fed to both sides. + +The runner feeds identical inputs to both sides, collects named outputs, and +uses Comparator + Tolerance to produce a structured CompareReport. 
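+
+A minimal wiring sketch (illustrative only; ``my_reference``, ``my_target`` and
+``my_adapter`` stand for concrete implementations of the protocols below, and
+the tensor/selector names are hypothetical):
+
+    case = ComparisonCase(
+        name="my_op.smoke",
+        reference=my_reference,
+        target=my_target,
+        input_spec=InputSpec(
+            seed=0,
+            tensors={"x": TensorSpec((4, 8), torch.float32)},
+        ),
+        weight_adapter=my_adapter,
+        selectors=[OutputSelector(name="out", ref_key="out", tgt_key="out")],
+        tolerance=Tolerance(atol=1e-3, rtol=1e-3),
+    )
+    report = case.run()
+    print(report.summary())
+
+Factories registered with ``@register_case("model.scope")`` return such a case
+and are then discoverable from the CLI (``python -m llm.testing.hf_compare``).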
+ +Note: this module intentionally has zero dependencies on pypto / transformers / +safetensors. Concrete adapters and targets that need them live in sibling +modules and are imported on demand. +""" +from __future__ import annotations + +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +import torch + + +# --------------------------------------------------------------------------- +# Tensor I/O descriptors +# --------------------------------------------------------------------------- +@dataclass(frozen=True) +class TensorSpec: + """Shape + dtype + (optional) value generator for one input tensor.""" + + shape: tuple[int, ...] + dtype: torch.dtype + # How to synthesize values. None -> zeros. + sampler: Callable[[tuple[int, ...], torch.dtype, torch.Generator], torch.Tensor] | None = None + # Optional semantic tag so adapters/selectors can identify the tensor + # without relying on its dict key (e.g. "kv_cache", "rope_table"). + role: str | None = None + + +@dataclass +class InputSpec: + """Declarative input bundle, evaluated once and shared by ref & target.""" + + tensors: dict[str, TensorSpec] + seed: int = 0 + + def materialize(self, device: torch.device | str = "cpu") -> dict[str, torch.Tensor]: + gen = torch.Generator(device="cpu").manual_seed(self.seed) + out: dict[str, torch.Tensor] = {} + for name, spec in self.tensors.items(): + if spec.sampler is not None: + t = spec.sampler(spec.shape, spec.dtype, gen) + else: + t = torch.zeros(spec.shape, dtype=spec.dtype) + out[name] = t.to(device) if str(device) != "cpu" else t + return out + + +# --------------------------------------------------------------------------- +# Weight adapter +# --------------------------------------------------------------------------- +@runtime_checkable +class WeightAdapter(Protocol): + """Transforms an HF state_dict slice into kernel-ready tensors. + + Implementations can be declarative (see weight_adapter.DictAdapter) or + fully custom for fused / sharded / MoE-style layouts. + """ + + def adapt(self, hf_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: ... + + # Optional inverse direction: rebuild an HF-shaped tensor from kernel + # buffers. Useful when the target owns KV cache or fused weights and the + # reference needs them in HF layout. + def unadapt(self, kernel_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: # noqa: ARG002 + return {} + + +# --------------------------------------------------------------------------- +# Reference (HF side) +# --------------------------------------------------------------------------- +@runtime_checkable +class ReferenceModel(Protocol): + """Golden-reference callable. + + The returned dict's keys become the right-hand side of OutputSelector + pairs. A reference can wrap any of: + - an HF nn.Module (e.g. Qwen3DecoderLayer) + - a subset of one (e.g. just attention scope1) + - any torch function (for new ops with no HF counterpart) + """ + + name: str + + def prepare(self, hf_state: Mapping[str, torch.Tensor]) -> None: ... + + def forward(self, inputs: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: ... + + +# --------------------------------------------------------------------------- +# Target (PyPTO side) +# --------------------------------------------------------------------------- +@runtime_checkable +class TargetModel(Protocol): + """The system under test. 
+ + ``prepare()`` is called once and may compile a pypto program (results + should be cached so re-running the same case is cheap). ``run()`` is + invoked per case and returns the dict to be compared. + """ + + name: str + + def prepare(self) -> None: ... + + def run( + self, + inputs: Mapping[str, torch.Tensor], + weights: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: ... + + def teardown(self) -> None: ... + + +# --------------------------------------------------------------------------- +# Output selection & tolerance +# --------------------------------------------------------------------------- +SliceExpr = Any +"""A slice, int, tuple of them, or ``Ellipsis`` for full-range selection.""" + + +@dataclass(frozen=True) +class OutputSelector: + """Which output(s) to compare, plus optional reshape/slice.""" + + name: str # report label, e.g. "layer_out" + ref_key: str # key in ReferenceModel.forward() output + tgt_key: str # key in TargetModel.run() output + slice_: SliceExpr | None = None # applied to BOTH tensors before compare + cast_to: torch.dtype = torch.float32 # cast before computing metrics + # Optional post-processing hook applied AFTER slicing (e.g. gather written + # KV rows or transpose to a common layout). Run independently on ref/tgt. + postprocess: Callable[[torch.Tensor], torch.Tensor] | None = None + + +@dataclass(frozen=True) +class Tolerance: + atol: float = 5e-3 + rtol: float = 5e-3 + # Metrics to compute and include in the report. Unknown names ignored. + metrics: tuple[str, ...] = ("max_abs", "mean_abs", "max_rel", "cosine", "pass_rate") + # pass_rate: fraction of elements satisfying |a-b| <= atol + rtol*|b|. + pass_rate_threshold: float = 0.999 + # How many worst-offender entries to record per selector. + worst_offenders: int = 5 + + +# --------------------------------------------------------------------------- +# Report +# --------------------------------------------------------------------------- +@dataclass +class SelectorResult: + selector: str + metrics: dict[str, float] + passed: bool + # (flat_index, ref_val, tgt_val) tuples for the largest absolute diffs. + worst_offenders: list[tuple[int, float, float]] = field(default_factory=list) + # Optional free-form note (e.g. shape mismatch, cast applied, ...). + note: str = "" + + +@dataclass +class CompareReport: + case_name: str + passed: bool + results: list[SelectorResult] + # Free-form diagnostics (compile time, device, versions, env, ...). + meta: dict[str, Any] = field(default_factory=dict) + + def to_json(self) -> dict[str, Any]: + return { + "case": self.case_name, + "passed": self.passed, + "results": [ + { + "selector": r.selector, + "passed": r.passed, + "metrics": r.metrics, + "worst_offenders": r.worst_offenders, + "note": r.note, + } + for r in self.results + ], + "meta": self.meta, + } + + def summary(self) -> str: + status = "PASS" if self.passed else "FAIL" + lines = [f"[{status}] case={self.case_name}"] + for r in self.results: + tag = "PASS" if r.passed else "FAIL" + metric_str = " ".join(f"{k}={v:.4e}" for k, v in r.metrics.items()) + lines.append(f" [{tag}] {r.selector:24s} {metric_str}") + if r.note: + lines.append(f" note: {r.note}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Comparison case (top-level) +# --------------------------------------------------------------------------- +# HF weight source: a path on disk, a preloaded state dict, or a callable +# returning one. 
Concrete loaders live in ``weight_adapter`` / case modules. +HFWeightSource = ( + str + | Mapping[str, torch.Tensor] + | Callable[[], Mapping[str, torch.Tensor]] + | None +) + + +@dataclass +class ComparisonCase: + name: str + reference: ReferenceModel + target: TargetModel + input_spec: InputSpec + weight_adapter: WeightAdapter + selectors: list[OutputSelector] + tolerance: Tolerance = field(default_factory=Tolerance) + hf_weights: HFWeightSource = None + # Optional pre/post hooks for diagnostics or in-place tweaks. + on_inputs: Callable[[dict[str, torch.Tensor]], None] | None = None + on_report: Callable[[CompareReport], None] | None = None + + def run(self) -> CompareReport: + """Drive the comparison. Implementation lives in runner.py.""" + # Imported lazily to keep base.py free of heavy dependencies. + from .runner import run_case + + return run_case(self) + + +# --------------------------------------------------------------------------- +# Chained cases (scope1 -> scope2 -> scope3) +# --------------------------------------------------------------------------- +@dataclass +class ChainStep: + case: ComparisonCase + # Map produced reference outputs onto the next step's InputSpec tensor + # names. ``{ref_output_key: next_input_name}``. + forward_map: dict[str, str] = field(default_factory=dict) + + +@dataclass +class ChainedComparisonCase: + """Run multiple cases where step N's reference outputs feed step N+1.""" + + name: str + steps: list[ChainStep] + stop_on_fail: bool = True + + def run(self) -> list[CompareReport]: + from .runner import run_chain + + return run_chain(self) + + +# --------------------------------------------------------------------------- +# Registry (so CLI can discover cases by name) +# --------------------------------------------------------------------------- +_CASE_REGISTRY: dict[str, Callable[..., ComparisonCase]] = {} + + +def register_case( + name: str, +) -> Callable[[Callable[..., ComparisonCase]], Callable[..., ComparisonCase]]: + """Decorator: ``@register_case("qwen3_14b.decode")`` on a factory function. + + The decorated function should return a fully-constructed ComparisonCase + when called (typically with optional kwargs like ``hf_model_path=...``). + """ + + def deco(fn: Callable[..., ComparisonCase]) -> Callable[..., ComparisonCase]: + if name in _CASE_REGISTRY: + raise ValueError(f"Case {name!r} already registered.") + _CASE_REGISTRY[name] = fn + return fn + + return deco + + +def _autodiscover_cases() -> None: + """Import every module in ``llm.testing.hf_compare.cases`` to register cases.""" + import importlib + import pkgutil + + try: + cases_pkg = importlib.import_module("llm.testing.hf_compare.cases") + except ImportError: + return + for modinfo in pkgutil.iter_modules(cases_pkg.__path__): + importlib.import_module(f"llm.testing.hf_compare.cases.{modinfo.name}") + + +def get_case(name: str, **kwargs: Any) -> ComparisonCase: + if name not in _CASE_REGISTRY: + _autodiscover_cases() + if name not in _CASE_REGISTRY: + raise KeyError( + f"Unknown case: {name!r}. Registered: {sorted(_CASE_REGISTRY)}" + ) + return _CASE_REGISTRY[name](**kwargs) + + +def list_cases() -> list[str]: + _autodiscover_cases() + return sorted(_CASE_REGISTRY) diff --git a/llm/testing/hf_compare/cases/__init__.py b/llm/testing/hf_compare/cases/__init__.py new file mode 100644 index 00000000..1d7843bc --- /dev/null +++ b/llm/testing/hf_compare/cases/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) PyPTO Contributors. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Registered comparison cases. + +Every module in this package is auto-imported by the CLI so that each +``@register_case("...")`` decorator runs at import time. +""" diff --git a/llm/testing/hf_compare/cases/_qwen3_14b_common.py b/llm/testing/hf_compare/cases/_qwen3_14b_common.py new file mode 100644 index 00000000..28815af7 --- /dev/null +++ b/llm/testing/hf_compare/cases/_qwen3_14b_common.py @@ -0,0 +1,144 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Shared constants, weight adapter, and kernel loader for Qwen3-14B hf_compare cases.""" +from __future__ import annotations + +import importlib.util +from pathlib import Path +from types import ModuleType +from typing import Any + +import torch + +from ..weight_adapter import Cast, Contiguous, DictAdapter, Map, Transpose, View + + +# --------------------------------------------------------------------------- +# Model constants +# --------------------------------------------------------------------------- +DEFAULT_MODEL_PATH = "/data/linyifan/models/Qwen3-14B" + +HIDDEN = 5120 +HEAD_DIM = 128 +NUM_HEADS = 40 +NUM_KV_HEADS = 8 +INTER = 17408 + + +# --------------------------------------------------------------------------- +# Kernel input ordering. Decode and fused decode share an identical signature; +# prefill reorders seq_lens / paged-attention metadata. 
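+# Targets bind tensors to the compiled kernel positionally, conceptually as
+# ``[tensors[name] for name in DECODE_SPEC_ORDER]`` (this is how the e2e NPU
+# runner does it; the exact binding inside PyPTOKernelTarget is assumed here).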
+# --------------------------------------------------------------------------- +DECODE_SPEC_ORDER: list[str] = [ + "hidden_states", + "input_rms_weight", + "wq", + "wk", + "wv", + "q_norm_weight", + "k_norm_weight", + "seq_lens", + "block_table", + "slot_mapping", + "rope_cos", + "rope_sin", + "k_cache", + "v_cache", + "wo", + "post_rms_weight", + "w_gate", + "w_up", + "w_down", + "out", +] + +PREFILL_SPEC_ORDER: list[str] = [ + "hidden_states", + "seq_lens", + "input_rms_weight", + "wq", + "wk", + "wv", + "q_norm_weight", + "k_norm_weight", + "rope_cos", + "rope_sin", + "block_table", + "slot_mapping", + "k_cache", + "v_cache", + "wo", + "post_rms_weight", + "w_gate", + "w_up", + "w_down", + "out", +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def coerce_bool(v: Any) -> bool: + if isinstance(v, bool): + return v + return str(v).lower() in {"1", "true", "yes", "on"} + + +def load_kernel_module(rel_path: str, module_name: str) -> ModuleType: + """Load a kernel ``.py`` file by repo-relative path. + + ``rel_path`` is resolved against the pypto-lib repo root (four levels up + from this file). ``module_name`` is the sys.modules key the loaded module + is bound to. + """ + repo_root = Path(__file__).resolve().parents[4] + kernel_path = repo_root / rel_path + spec = importlib.util.spec_from_file_location(module_name, kernel_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load kernel from {kernel_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def single_layer_adapter(layer_idx: int = 0) -> DictAdapter: + """DictAdapter mapping HF Qwen3-14B layer weights to kernel inputs. + + Identical layout is consumed by the single-layer decode and prefill + kernels; pick a different ``layer_idx`` to source weights from a non-zero + HF layer. + """ + return DictAdapter( + prefix=f"model.layers.{layer_idx}.", + mapping={ + "input_rms_weight": Map("input_layernorm.weight", + ops=[View([1, HIDDEN]), Cast(torch.float32)]), + "wq": Map("self_attn.q_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + "wk": Map("self_attn.k_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + "wv": Map("self_attn.v_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + "q_norm_weight": Map("self_attn.q_norm.weight", + ops=[View([1, HEAD_DIM]), Cast(torch.float32)]), + "k_norm_weight": Map("self_attn.k_norm.weight", + ops=[View([1, HEAD_DIM]), Cast(torch.float32)]), + "wo": Map("self_attn.o_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + "post_rms_weight": Map("post_attention_layernorm.weight", + ops=[View([1, HIDDEN]), Cast(torch.float32)]), + "w_gate": Map("mlp.gate_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + "w_up": Map("mlp.up_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + "w_down": Map("mlp.down_proj.weight", + ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + }, + ) diff --git a/llm/testing/hf_compare/cases/qwen3_14b_decode.py b/llm/testing/hf_compare/cases/qwen3_14b_decode.py new file mode 100644 index 00000000..faba0e59 --- /dev/null +++ b/llm/testing/hf_compare/cases/qwen3_14b_decode.py @@ -0,0 +1,288 @@ +# Copyright (c) PyPTO Contributors. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Comparison case: Qwen3-14B decode (single layer) vs Hugging Face Qwen3DecoderLayer. + +Port of ``examples/models/qwen3/14b/compare_with_hf.py`` onto the +``llm.testing.hf_compare`` framework. Run via: + + python -m llm.testing.hf_compare run qwen3_14b.decode \\ + -k hf_model_path=/data/linyifan/models/Qwen3-14B \\ + -k platform=a2a3 -k batch=16 + # CPU-only (use the script's golden_qwen3_decode as target): + python -m llm.testing.hf_compare run qwen3_14b.decode -k cpu_only=true + + # For a shorter context length (alias: ctx_len): + python -m llm.testing.hf_compare run qwen3_14b.decode -k cpu_only=true -k seq_len=128 +""" +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import torch + +from ..base import ComparisonCase, InputSpec, OutputSelector, TensorSpec, Tolerance, register_case +from ..input_sampler import build_rope_tables, int_fill, uniform +from ..paged_kv import ( + compute_decode_slot_mapping, + init_block_table_identity, + paged_to_dense_history, + select_kv_at_decode_slot, +) +from ..reference import CallableReference +from ..target import CallableTarget, PyPTOKernelTarget +from ._qwen3_14b_common import ( + DECODE_SPEC_ORDER as SPEC_ORDER, + DEFAULT_MODEL_PATH, + HEAD_DIM, + HIDDEN, + NUM_KV_HEADS, + coerce_bool, + load_kernel_module, + single_layer_adapter, +) + + +# --------------------------------------------------------------------------- +# HF reference forward +# --------------------------------------------------------------------------- +def _hf_decode_forward( + inputs: Mapping[str, torch.Tensor], + hf_state: Mapping[str, torch.Tensor], + *, + model_path: str, + max_seq: int, + block_size: int, + hf_dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + """Run Qwen3DecoderLayer for layer 0 with the given inputs / KV history.""" + from transformers import AutoConfig + from transformers.cache_utils import DynamicCache + from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer + + cfg = AutoConfig.from_pretrained(model_path) + cfg._attn_implementation = "eager" + layer = Qwen3DecoderLayer(cfg, layer_idx=0).to(hf_dtype).eval() + + prefix = "model.layers.0." 
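+    # Slice the full HF state dict down to this layer and strip the prefix so a
+    # bare Qwen3DecoderLayer can load it; missing or unexpected keys are treated
+    # as hard errors below.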
+ sd = {k[len(prefix):]: v.to(hf_dtype) for k, v in hf_state.items() if k.startswith(prefix)} + missing, unexpected = layer.load_state_dict(sd, strict=False) + if missing: + raise RuntimeError(f"[qwen3.decode] missing HF keys: {missing}") + if unexpected: + raise RuntimeError(f"[qwen3.decode] unexpected HF keys: {unexpected}") + + batch = inputs["hidden_states"].shape[0] + seq_len = int(inputs["seq_lens"][0].item()) + pos = seq_len - 1 + + block_table = inputs["block_table"] + max_blocks_per_seq = (max_seq + block_size - 1) // block_size + k_hist, v_hist = paged_to_dense_history( + inputs["k_cache"], inputs["v_cache"], block_table, + batch=batch, ctx_len=pos, + num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, + block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, + hf_dtype=hf_dtype, + ) + + cache = DynamicCache() + cache.update(k_hist, v_hist, layer_idx=0) + + cos = inputs["rope_cos"][pos:pos + 1].to(hf_dtype).unsqueeze(0).expand(batch, -1, -1) + sin = inputs["rope_sin"][pos:pos + 1].to(hf_dtype).unsqueeze(0).expand(batch, -1, -1) + hs_in = inputs["hidden_states"].to(hf_dtype).unsqueeze(1) + pos_ids = torch.full((batch, 1), pos, dtype=torch.long) + + with torch.no_grad(): + out = layer( + hidden_states=hs_in, + attention_mask=None, + position_ids=pos_ids, + past_key_values=cache, + position_embeddings=(cos, sin), + ) + if isinstance(out, tuple): + out = out[0] + out_hs = out.squeeze(1) + + new_k = cache.layers[0].keys[:, :, pos:pos + 1, :].squeeze(2) # [B, nkv, head_dim] + new_v = cache.layers[0].values[:, :, pos:pos + 1, :].squeeze(2) + + return {"layer_out": out_hs, "k_pos": new_k, "v_pos": new_v} + + +# --------------------------------------------------------------------------- +# Case factory +# --------------------------------------------------------------------------- +@register_case("qwen3_14b.decode") +def build( + hf_model_path: str = DEFAULT_MODEL_PATH, + cpu_only: Any = False, + platform: str = "a2a3", + device_id: int | str | None = None, + batch: int = 1, + seed: int = 0, + seq_len: int | None = None, + max_seq: int | None = None, + ctx_len: int | None = None, + atol: float = 5e-3, + rtol: float = 5e-3, + hf_dtype: str = "fp32", +) -> ComparisonCase: + cpu_only = coerce_bool(cpu_only) + device_id = int(device_id) if device_id is not None else None + batch = int(batch) + seed = int(seed) + seq_len = int(seq_len) if seq_len is not None else None + max_seq = int(max_seq) if max_seq is not None else None + ctx_len = int(ctx_len) if ctx_len is not None else None + atol = float(atol) + rtol = float(rtol) + hf_dtype_t = torch.float32 if hf_dtype == "fp32" else torch.bfloat16 + + dec = load_kernel_module( + "models/qwen3/14b/qwen3_14b_decode.py", + "_qwen3_14b_decode_kernel", + ) + + if batch <= 0: + raise ValueError(f"batch must be positive, got {batch}") + # Decode is a single-token step that assumes position = (seq_len - 1). + # `seq_len` is the context length for this step; `max_seq` controls the + # cache capacity / RoPE table size (defaults to the kernel's MAX_SEQ). + # Back-compat: `ctx_len` is an alias for `seq_len` if `seq_len` is unset. 
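+    # Worked example for the cache sizing below (hypothetical numbers): with
+    # batch=2, max_seq=256 and BLOCK_SIZE=128, max_blocks_per_seq = 2,
+    # num_blocks = 4 and cache_rows = 4 * NUM_KV_HEADS * 128 rows of HEAD_DIM.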
+ context_len = seq_len if seq_len is not None else (ctx_len or dec.MAX_SEQ) + max_seq_eff = max_seq or dec.MAX_SEQ + if context_len > max_seq_eff: + raise ValueError(f"seq_len={context_len} > max_seq={max_seq_eff}") + + max_blocks_per_seq = (max_seq_eff + dec.BLOCK_SIZE - 1) // dec.BLOCK_SIZE + num_blocks = batch * max_blocks_per_seq + cache_rows = num_blocks * NUM_KV_HEADS * dec.BLOCK_SIZE + + # ---- inputs ------------------------------------------------------------ + cos_tab, sin_tab = build_rope_tables(max_seq_eff, HEAD_DIM, base=1_000_000.0) + input_spec = InputSpec( + seed=seed, + tensors={ + "hidden_states": TensorSpec((batch, HIDDEN), torch.bfloat16, sampler=uniform()), + "seq_lens": TensorSpec((batch,), torch.int32, sampler=int_fill(context_len)), + "block_table": TensorSpec((batch * max_blocks_per_seq,), torch.int32), + "slot_mapping": TensorSpec((batch,), torch.int32), + "k_cache": TensorSpec((cache_rows, HEAD_DIM), torch.bfloat16, sampler=uniform()), + "v_cache": TensorSpec((cache_rows, HEAD_DIM), torch.bfloat16, sampler=uniform()), + "out": TensorSpec((batch, HIDDEN), torch.bfloat16), + # RoPE tables are deterministic; inject via constant samplers. + "rope_cos": TensorSpec(tuple(cos_tab.shape), torch.float32, + sampler=lambda s, d, g: cos_tab.clone().to(d)), + "rope_sin": TensorSpec(tuple(sin_tab.shape), torch.float32, + sampler=lambda s, d, g: sin_tab.clone().to(d)), + }, + ) + + # ---- weight adapter ---------------------------------------------------- + adapter = single_layer_adapter(layer_idx=0) + + # ---- reference --------------------------------------------------------- + reference = CallableReference( + name="hf.Qwen3DecoderLayer", + fn=lambda inp, st: _hf_decode_forward( + inp, st, model_path=hf_model_path, max_seq=max_seq_eff, + block_size=int(dec.BLOCK_SIZE), hf_dtype=hf_dtype_t, + ), + ) + + # ---- target ------------------------------------------------------------ + def _post_run(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + kv_kwargs = dict( + num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, + block_size=int(dec.BLOCK_SIZE), + ) + return { + "layer_out": tensors["out"].clone(), + "k_pos": select_kv_at_decode_slot( + tensors["k_cache"], tensors["slot_mapping"], **kv_kwargs, + ), + "v_pos": select_kv_at_decode_slot( + tensors["v_cache"], tensors["slot_mapping"], **kv_kwargs, + ), + } + + if cpu_only: + # Use the script's PyTorch golden as target (HF vs golden_qwen3_decode). + def _golden(inp: Mapping[str, torch.Tensor], _w: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + ref_inputs = {**_w, **inp} + dec.golden_qwen3_decode(ref_inputs) + kv_kwargs = dict( + num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, + block_size=int(dec.BLOCK_SIZE), + ) + return { + "layer_out": ref_inputs["out"].clone(), + "k_pos": select_kv_at_decode_slot( + ref_inputs["k_cache"], ref_inputs["slot_mapping"], **kv_kwargs, + ), + "v_pos": select_kv_at_decode_slot( + ref_inputs["v_cache"], ref_inputs["slot_mapping"], **kv_kwargs, + ), + } + target = CallableTarget(name="pytorch.golden_qwen3_decode", fn=_golden) + # Note: the script's golden does not write back to the input dict, so + # k_pos / v_pos comparisons are degenerate. Restrict to layer_out. 
+ selectors = [ + OutputSelector(name="layer_out", ref_key="layer_out", tgt_key="layer_out"), + ] + else: + target = PyPTOKernelTarget( + name=f"pypto.qwen3_14b_decode[{platform}]", + build_program=lambda: dec.build_qwen3_decode_program(batch=batch, max_seq=max_seq_eff), + spec_order=SPEC_ORDER, + platform=platform, + device_id=device_id, + post_run=_post_run, + ) + selectors = [ + OutputSelector(name="layer_out", ref_key="layer_out", tgt_key="layer_out"), + OutputSelector(name="k_cache[pos]", ref_key="k_pos", tgt_key="k_pos"), + OutputSelector(name="v_cache[pos]", ref_key="v_pos", tgt_key="v_pos"), + ] + + return ComparisonCase( + name="qwen3_14b.decode", + reference=reference, + target=target, + input_spec=input_spec, + weight_adapter=adapter, + selectors=selectors, + tolerance=Tolerance(atol=atol, rtol=rtol), + hf_weights=hf_model_path, # path -> auto-loaded by runner + on_inputs=lambda t: _init_paged_attention_decode_inputs( + t, + batch=batch, + block_size=int(dec.BLOCK_SIZE), + max_blocks_per_seq=max_blocks_per_seq, + ), + ) + + +def _init_paged_attention_decode_inputs( + tensors: dict[str, torch.Tensor], + *, + batch: int, + block_size: int, + max_blocks_per_seq: int, +) -> None: + init_block_table_identity( + tensors["block_table"], batch=batch, max_blocks_per_seq=max_blocks_per_seq, + ) + compute_decode_slot_mapping( + tensors["slot_mapping"], tensors["seq_lens"], + batch=batch, block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, + ) diff --git a/llm/testing/hf_compare/cases/qwen3_14b_decode_full.py b/llm/testing/hf_compare/cases/qwen3_14b_decode_full.py new file mode 100644 index 00000000..e2a9a388 --- /dev/null +++ b/llm/testing/hf_compare/cases/qwen3_14b_decode_full.py @@ -0,0 +1,320 @@ +"""Comparison case: Qwen3-14B fused multi-layer decode vs HF. + +Stacks weights across `num_layers` and compares the fused pypto kernel +output against `num_layers` HF Qwen3DecoderLayer instances in sequence. + +Run via: + + python -m llm.testing.hf_compare run qwen3_14b.decode_full \\ + -k hf_model_path=/data/linyifan/models/Qwen3-14B \\ + -k platform=a2a3 -k num_layers=4 -k seq_len=128 +""" +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import torch + +from ..base import ComparisonCase, InputSpec, OutputSelector, TensorSpec, Tolerance, register_case +from ..base import WeightAdapter +from ..input_sampler import build_rope_tables, int_fill, uniform +from ..paged_kv import ( + compute_decode_slot_mapping, + init_block_table_identity, + paged_to_dense_history, + select_kv_at_decode_slot, +) +from ..reference import CallableReference +from ..target import PyPTOKernelTarget +from ._qwen3_14b_common import ( + DECODE_SPEC_ORDER as SPEC_ORDER, + DEFAULT_MODEL_PATH, + HEAD_DIM, + HIDDEN, + INTER, + NUM_HEADS, + NUM_KV_HEADS, + load_kernel_module, +) + + +# --------------------------------------------------------------------------- +# HF reference: run num_layers Qwen3DecoderLayer in sequence. 
+# --------------------------------------------------------------------------- +def _hf_decode_full_forward( + inputs: Mapping[str, torch.Tensor], + hf_state: Mapping[str, torch.Tensor], + *, + model_path: str, + max_seq: int, + num_layers: int, + block_size: int, + hf_dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + from transformers import AutoConfig + from transformers.cache_utils import DynamicCache + from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer + + cfg = AutoConfig.from_pretrained(model_path) + cfg._attn_implementation = "eager" + + layers = [] + for li in range(num_layers): + layer = Qwen3DecoderLayer(cfg, layer_idx=li).to(hf_dtype).eval() + prefix = f"model.layers.{li}." + sd = {k[len(prefix):]: v.to(hf_dtype) for k, v in hf_state.items() if k.startswith(prefix)} + missing, unexpected = layer.load_state_dict(sd, strict=False) + if missing: + raise RuntimeError(f"[decode_full] L{li} missing HF keys: {missing}") + if unexpected: + raise RuntimeError(f"[decode_full] L{li} unexpected HF keys: {unexpected}") + layers.append(layer) + + batch = inputs["hidden_states"].shape[0] + seq_len = int(inputs["seq_lens"][0].item()) + pos = seq_len - 1 + + block_table = inputs["block_table"] + max_blocks_per_seq = (max_seq + block_size - 1) // block_size + layer_cache_rows = batch * max_blocks_per_seq * NUM_KV_HEADS * block_size + + cache = DynamicCache() + for li in range(num_layers): + k_hist, v_hist = paged_to_dense_history( + inputs["k_cache"], inputs["v_cache"], block_table, + batch=batch, ctx_len=pos, + num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, + block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, + hf_dtype=hf_dtype, + layer_offset_rows=li * layer_cache_rows, + ) + cache.update(k_hist, v_hist, layer_idx=li) + + cos = inputs["rope_cos"][pos:pos + 1].to(hf_dtype).unsqueeze(0).expand(batch, -1, -1) + sin = inputs["rope_sin"][pos:pos + 1].to(hf_dtype).unsqueeze(0).expand(batch, -1, -1) + pos_ids = torch.full((batch, 1), pos, dtype=torch.long) + + hs = inputs["hidden_states"].to(hf_dtype).unsqueeze(1) + per_layer_out: list[torch.Tensor] = [] + per_layer_k: list[torch.Tensor] = [] + per_layer_v: list[torch.Tensor] = [] + with torch.no_grad(): + for li in range(num_layers): + out = layers[li]( + hidden_states=hs, + attention_mask=None, + position_ids=pos_ids, + past_key_values=cache, + position_embeddings=(cos, sin), + ) + if isinstance(out, tuple): + out = out[0] + hs = out + per_layer_out.append(hs.squeeze(1).clone()) + per_layer_k.append(cache.layers[li].keys[:, :, pos:pos + 1, :].squeeze(2).clone()) + per_layer_v.append(cache.layers[li].values[:, :, pos:pos + 1, :].squeeze(2).clone()) + + result: dict[str, torch.Tensor] = { + "layer_out": hs.squeeze(1), + } + for li in range(num_layers): + result[f"layer{li:02d}_hidden"] = per_layer_out[li] + result[f"layer{li:02d}_k_pos"] = per_layer_k[li] + result[f"layer{li:02d}_v_pos"] = per_layer_v[li] + return result + + +# --------------------------------------------------------------------------- +# Stacked weight adapter. 
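+# Produces e.g. ``wq`` as a [num_layers * HIDDEN, NUM_HEADS * HEAD_DIM] bf16
+# tensor and each norm weight as [num_layers, dim] fp32, the layout assumed to
+# be consumed by the fused multi-layer kernel.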
+# --------------------------------------------------------------------------- +class StackedAdapter(WeightAdapter): + """Stacks per-layer HF weights into the fused kernel layout.""" + + def __init__(self, num_layers: int, hidden: int, head_dim: int): + self.num_layers = num_layers + self.hidden = hidden + self.head_dim = head_dim + + def adapt(self, hf_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + L = self.num_layers + H = self.hidden + D = self.head_dim + + def get(key_template: str, li: int) -> torch.Tensor: + return hf_state[f"model.layers.{li}.{key_template}"] + + def stack_norm(name: str, dim_per_layer: int) -> torch.Tensor: + rows = [get(name, li).view(1, dim_per_layer).float() for li in range(L)] + return torch.cat(rows, dim=0).contiguous() + + def stack_proj(name: str, out_dim: int) -> torch.Tensor: + # HF stores weights as [out, in]; kernel expects [in, out] (transposed). + rows = [get(name, li).t().contiguous().to(torch.bfloat16) for li in range(L)] + return torch.cat(rows, dim=0).contiguous() + + return { + "input_rms_weight": stack_norm("input_layernorm.weight", H), + "wq": stack_proj("self_attn.q_proj.weight", NUM_HEADS * D), + "wk": stack_proj("self_attn.k_proj.weight", NUM_KV_HEADS * D), + "wv": stack_proj("self_attn.v_proj.weight", NUM_KV_HEADS * D), + "q_norm_weight": stack_norm("self_attn.q_norm.weight", D), + "k_norm_weight": stack_norm("self_attn.k_norm.weight", D), + "wo": stack_proj("self_attn.o_proj.weight", H), + "post_rms_weight": stack_norm("post_attention_layernorm.weight", H), + "w_gate": stack_proj("mlp.gate_proj.weight", INTER), + "w_up": stack_proj("mlp.up_proj.weight", INTER), + "w_down": stack_proj("mlp.down_proj.weight", H), + } + + +# --------------------------------------------------------------------------- +# Case factory +# --------------------------------------------------------------------------- +@register_case("qwen3_14b.decode_full") +def build( + hf_model_path: str = DEFAULT_MODEL_PATH, + platform: str = "a2a3", + device_id: int | str | None = None, + batch: int = 1, + seed: int = 0, + seq_len: int | None = None, + max_seq: int | None = None, + num_layers: int = 4, + atol: float = 1.5e-2, + rtol: float = 1.5e-2, + pass_rate: float = 0.99, + hf_dtype: str = "fp32", +) -> ComparisonCase: + # Tolerance defaults are sized for the BF16-faithful NPU path vs FP32 HF + # reference. K-cache (after K-norm + RoPE) sees ~3e-3 mean / ~0.2 max + # absolute error from intermediate BF16 round-trips on UB boundaries; a + # ~2x BF16 ulp budget (1.5e-2) absorbs that, and 0.99 pass_rate accepts + # the BF16 long-tail outliers. Override via -k atol/-k rtol/-k pass_rate. 
+ device_id = int(device_id) if device_id is not None else None + batch = int(batch) + seed = int(seed) + num_layers = int(num_layers) + seq_len = int(seq_len) if seq_len is not None else None + max_seq = int(max_seq) if max_seq is not None else None + atol = float(atol) + rtol = float(rtol) + pass_rate = float(pass_rate) + hf_dtype_t = torch.float32 if hf_dtype == "fp32" else torch.bfloat16 + + dec = load_kernel_module( + "models/qwen3/14b/qwen3_14b_decode_full.py", + "_qwen3_14b_decode_full_kernel", + ) + + if batch <= 0: + raise ValueError(f"batch must be positive, got {batch}") + context_len = seq_len if seq_len is not None else dec.MAX_SEQ + max_seq_eff = max_seq or dec.MAX_SEQ + if context_len > max_seq_eff: + raise ValueError(f"seq_len={context_len} > max_seq={max_seq_eff}") + + block_size = int(dec.BLOCK_SIZE) + max_blocks_per_seq = (max_seq_eff + block_size - 1) // block_size + layer_cache_rows = batch * max_blocks_per_seq * NUM_KV_HEADS * block_size + cache_rows = num_layers * layer_cache_rows + + cos_tab, sin_tab = build_rope_tables(max_seq_eff, HEAD_DIM, base=1_000_000.0) + input_spec = InputSpec( + seed=seed, + tensors={ + "hidden_states": TensorSpec((batch, HIDDEN), torch.bfloat16, sampler=uniform()), + "seq_lens": TensorSpec((batch,), torch.int32, sampler=int_fill(context_len)), + "block_table": TensorSpec((batch * max_blocks_per_seq,), torch.int32), + "slot_mapping": TensorSpec((batch,), torch.int32), + "k_cache": TensorSpec((cache_rows, HEAD_DIM), torch.bfloat16, sampler=uniform()), + "v_cache": TensorSpec((cache_rows, HEAD_DIM), torch.bfloat16, sampler=uniform()), + "out": TensorSpec((batch, HIDDEN), torch.bfloat16), + "rope_cos": TensorSpec(tuple(cos_tab.shape), torch.float32, + sampler=lambda s, d, g: cos_tab.clone().to(d)), + "rope_sin": TensorSpec(tuple(sin_tab.shape), torch.float32, + sampler=lambda s, d, g: sin_tab.clone().to(d)), + }, + ) + + adapter = StackedAdapter(num_layers=num_layers, hidden=HIDDEN, head_dim=HEAD_DIM) + + reference = CallableReference( + name="hf.Qwen3DecoderLayer_x{}".format(num_layers), + fn=lambda inp, st: _hf_decode_full_forward( + inp, st, + model_path=hf_model_path, max_seq=max_seq_eff, + num_layers=num_layers, block_size=block_size, hf_dtype=hf_dtype_t, + ), + ) + + def _post_run(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + result: dict[str, torch.Tensor] = {"layer_out": tensors["out"].clone()} + kv_kwargs = dict( + num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, block_size=block_size, + ) + for li in range(num_layers): + offset = li * layer_cache_rows + result[f"layer{li:02d}_k_pos"] = select_kv_at_decode_slot( + tensors["k_cache"], tensors["slot_mapping"], + layer_offset_rows=offset, **kv_kwargs, + ) + result[f"layer{li:02d}_v_pos"] = select_kv_at_decode_slot( + tensors["v_cache"], tensors["slot_mapping"], + layer_offset_rows=offset, **kv_kwargs, + ) + return result + + target = PyPTOKernelTarget( + name=f"pypto.qwen3_14b_decode_full[{platform}]", + build_program=lambda: dec.build_qwen3_decode_program( + batch=batch, max_seq=max_seq_eff, num_layers=num_layers, + ), + spec_order=SPEC_ORDER, + platform=platform, + device_id=device_id, + post_run=_post_run, + ) + + selectors: list[OutputSelector] = [ + OutputSelector(name="layer_out", ref_key="layer_out", tgt_key="layer_out"), + ] + for li in range(num_layers): + selectors.append(OutputSelector( + name=f"layer{li:02d}.k_pos", + ref_key=f"layer{li:02d}_k_pos", tgt_key=f"layer{li:02d}_k_pos", + )) + selectors.append(OutputSelector( + name=f"layer{li:02d}.v_pos", + 
ref_key=f"layer{li:02d}_v_pos", tgt_key=f"layer{li:02d}_v_pos", + )) + + return ComparisonCase( + name="qwen3_14b.decode_full", + reference=reference, + target=target, + input_spec=input_spec, + weight_adapter=adapter, + selectors=selectors, + tolerance=Tolerance(atol=atol, rtol=rtol, pass_rate_threshold=pass_rate), + hf_weights=hf_model_path, + on_inputs=lambda t: _init_paged_attention_inputs( + t, batch=batch, block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, + ), + ) + + +def _init_paged_attention_inputs( + tensors: dict[str, torch.Tensor], + *, + batch: int, + block_size: int, + max_blocks_per_seq: int, +) -> None: + init_block_table_identity( + tensors["block_table"], batch=batch, max_blocks_per_seq=max_blocks_per_seq, + ) + compute_decode_slot_mapping( + tensors["slot_mapping"], tensors["seq_lens"], + batch=batch, block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, + ) diff --git a/llm/testing/hf_compare/cases/qwen3_14b_e2e.py b/llm/testing/hf_compare/cases/qwen3_14b_e2e.py new file mode 100644 index 00000000..d20c298c --- /dev/null +++ b/llm/testing/hf_compare/cases/qwen3_14b_e2e.py @@ -0,0 +1,700 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Comparison case: Qwen3-14B end-to-end generation vs Hugging Face. 
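+
+The target drives the single-layer Qwen3-14B prefill and decode kernels
+(either the CPU golden paths below or the compiled NPU programs) through a
+greedy generation loop, while the reference runs the same prompt through
+``Qwen3ForCausalLM`` truncated to ``num_layers`` layers.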
+ +Run via: + + python -m llm.testing.hf_compare run qwen3_14b.e2e \ + -k hf_model_path=/data/linyifan/models/Qwen3-14B \ + -k prompt_ids=151644,872,198 \ + -k max_new_tokens=4 \ + -k num_layers=2 \ + -k cpu_only=true +""" +from __future__ import annotations + +import importlib.util +import math +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +import torch + +from ..base import ComparisonCase, InputSpec, OutputSelector, TensorSpec, Tolerance, register_case +from ..reference import CallableReference +from ..target import CallableTarget, _resolve_device_id +from ..weight_adapter import PassthroughAdapter + + +DEFAULT_MODEL_PATH = "/data/linyifan/models/Qwen3-14B" +DEFAULT_PROMPT_IDS = "151644,872,198" +EPS = 1e-6 + + +def _load_qwen3_modules() -> tuple[Any, Any]: + """Load the decode + prefill kernel modules from ``models/qwen3/14b/``.""" + repo_root = Path(__file__).resolve().parents[4] + + def _load(name: str, rel_path: str) -> Any: + module_path = repo_root / rel_path + spec = importlib.util.spec_from_file_location(name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load Qwen3-14B helper from {module_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + return ( + _load("_qwen3_14b_decode_kernel", "models/qwen3/14b/qwen3_14b_decode.py"), + _load("_qwen3_14b_prefill_kernel", "models/qwen3/14b/qwen3_14b_prefill.py"), + ) + + +def _coerce_bool(v: Any) -> bool: + if isinstance(v, bool): + return v + return str(v).lower() in {"1", "true", "yes", "on"} + + +def _parse_prompt_ids(prompt_ids: str | list[int] | tuple[int, ...]) -> list[int]: + if isinstance(prompt_ids, (list, tuple)): + return [int(v) for v in prompt_ids] + out = [int(part.strip()) for part in str(prompt_ids).split(",") if part.strip()] + if not out: + raise ValueError("prompt_ids must contain at least one token id") + return out + + +def _state_for_num_layers( + hf_state: Mapping[str, torch.Tensor], + *, + num_layers: int, +) -> dict[str, torch.Tensor]: + keep: dict[str, torch.Tensor] = {} + layer_prefixes = tuple(f"model.layers.{idx}." for idx in range(num_layers)) + for key, value in hf_state.items(): + if key.startswith(("model.embed_tokens.", "model.norm.", "lm_head.")): + keep[key] = value + continue + if key.startswith(layer_prefixes): + keep[key] = value + return keep + + +def _extract_layer_weights( + all_weights: Mapping[str, torch.Tensor], + *, + layer_idx: int, + hidden: int, + head_dim: int, +) -> dict[str, torch.Tensor]: + prefix = f"model.layers.{layer_idx}." 
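+    # w() fetches a layer-local HF tensor; proj() converts an HF [out, in]
+    # weight into the transposed, contiguous bf16 [in, out] layout consumed by
+    # the kernels (mirroring ``single_layer_adapter`` in the decode cases).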
+ + def w(name: str) -> torch.Tensor: + return all_weights[prefix + name] + + def proj(name: str) -> torch.Tensor: + return w(name).t().contiguous().to(torch.bfloat16) + + return { + "input_rms_weight": w("input_layernorm.weight").view(1, hidden).to(torch.float32), + "wq": proj("self_attn.q_proj.weight"), + "wk": proj("self_attn.k_proj.weight"), + "wv": proj("self_attn.v_proj.weight"), + "q_norm_weight": w("self_attn.q_norm.weight").view(1, head_dim).to(torch.float32), + "k_norm_weight": w("self_attn.k_norm.weight").view(1, head_dim).to(torch.float32), + "wo": proj("self_attn.o_proj.weight"), + "post_rms_weight": w("post_attention_layernorm.weight").view(1, hidden).to(torch.float32), + "w_gate": proj("mlp.gate_proj.weight"), + "w_up": proj("mlp.up_proj.weight"), + "w_down": proj("mlp.down_proj.weight"), + } + + +def _build_rope(max_seq: int, head_dim: int, base: float = 1_000_000.0) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + pos = torch.arange(max_seq, dtype=torch.float32) + freqs = torch.outer(pos, inv_freq) + emb = torch.cat([freqs, freqs], dim=-1) + return emb.cos().to(torch.float32), emb.sin().to(torch.float32) + + +def _embed_tokens(input_ids: torch.Tensor, embed_weight: torch.Tensor) -> torch.Tensor: + return embed_weight[input_ids].to(torch.bfloat16) + + +def _final_rmsnorm(hidden: torch.Tensor, norm_weight: torch.Tensor) -> torch.Tensor: + x = hidden.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + EPS + return x * torch.rsqrt(variance) * norm_weight.float() + + +def _lm_head_forward(hidden: torch.Tensor, lm_weight: torch.Tensor) -> torch.Tensor: + return torch.matmul(hidden.float(), lm_weight.float().t()) + + +def _greedy_sample(logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1) + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + half = x.shape[-1] // 2 + x_lo, x_hi = x[..., :half], x[..., half:] + cos_lo, cos_hi = cos[..., :half], cos[..., half:] + sin_lo, sin_hi = sin[..., :half], sin[..., half:] + return torch.cat([ + x_lo * cos_lo - x_hi * sin_lo, + x_hi * cos_hi + x_lo * sin_hi, + ], dim=-1) + + +def _run_cpu_prefill( + hidden_states: torch.Tensor, + seq_lens: torch.Tensor, + layer_weights: dict[str, torch.Tensor], + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + *, + batch: int, + max_seq: int, + head_dim: int, + eps: float = EPS, +) -> torch.Tensor: + w = layer_weights + hidden_size = w["wq"].shape[0] + kv_hidden = w["wk"].shape[1] + num_heads = hidden_size // head_dim + num_kv_heads = kv_hidden // head_dim + q_per_kv = num_heads // num_kv_heads + scale = 1.0 / math.sqrt(head_dim) + + wq_f = w["wq"].float() + wk_f = w["wk"].float() + wv_f = w["wv"].float() + wo_f = w["wo"].float() + w_gate_f = w["w_gate"].float() + w_up_f = w["w_up"].float() + w_down_f = w["w_down"].float() + rms_w = w["input_rms_weight"].float() + post_rms_w = w["post_rms_weight"].float() + q_norm_w = w["q_norm_weight"].float().view(1, 1, head_dim) + k_norm_w = w["k_norm_weight"].float().view(1, 1, head_dim) + + out = torch.zeros(batch, max_seq, hidden_size, dtype=torch.bfloat16) + + for b in range(batch): + seq = int(seq_lens[b].item()) + if seq <= 0: + continue + + x = hidden_states[b, :seq, :].float() + variance = x.square().mean(dim=-1, keepdim=True) + eps + normed = (x / torch.sqrt(variance) * rms_w).to(torch.bfloat16).float() + q_proj = normed @ wq_f + k_proj = normed @ wk_f + v_proj = 
normed @ wv_f + + q_view = q_proj.view(seq, num_heads, head_dim) + q_view = q_view * torch.rsqrt(q_view.pow(2).mean(-1, keepdim=True) + eps) * q_norm_w + k_view = k_proj.view(seq, num_kv_heads, head_dim) + k_view = k_view * torch.rsqrt(k_view.pow(2).mean(-1, keepdim=True) + eps) * k_norm_w + + cos = rope_cos[:seq, :].unsqueeze(1) + sin = rope_sin[:seq, :].unsqueeze(1) + k_rot = _apply_rope(k_view, cos, sin).to(torch.bfloat16) + v_bf16 = v_proj.view(seq, num_kv_heads, head_dim).to(torch.bfloat16) + q_rot = _apply_rope(q_view, cos, sin).to(torch.bfloat16) + + cache_off = b * num_kv_heads * max_seq + for ki in range(num_kv_heads): + row0 = cache_off + ki * max_seq + k_cache[row0:row0 + seq, :] = k_rot[:, ki, :] + v_cache[row0:row0 + seq, :] = v_bf16[:, ki, :] + + attn_result = torch.zeros(seq, hidden_size, dtype=torch.float32) + for ki in range(num_kv_heads): + row0 = cache_off + ki * max_seq + k_cached = k_cache[row0:row0 + seq, :] + v_cached = v_cache[row0:row0 + seq, :] + + for qi in range(q_per_kv): + hi = ki * q_per_kv + qi + q_head = q_rot[:, hi, :] + scores = (q_head.float() @ k_cached.float().T) * scale + causal = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1) + scores = scores + causal + attn_w = torch.softmax(scores, dim=-1) + attn_result[:, hi * head_dim:(hi + 1) * head_dim] = attn_w @ v_cached.float() + + attn_bf16 = attn_result.to(torch.bfloat16).float() + hs = hidden_states[b, :seq, :].float() + resid1 = torch.matmul(attn_bf16, wo_f) + hs + variance = resid1.pow(2).mean(-1, keepdim=True) + post_normed = (resid1 * torch.rsqrt(variance + eps) * post_rms_w).bfloat16().float() + gate = torch.matmul(post_normed, w_gate_f) + up = torch.matmul(post_normed, w_up_f) + mlp = (gate * torch.sigmoid(gate) * up).bfloat16().float() + down = torch.matmul(mlp, w_down_f) + out[b, :seq, :] = (down + resid1).bfloat16() + + return out + + +def _run_cpu_decode( + hidden_states: torch.Tensor, + seq_lens: torch.Tensor, + layer_weights: dict[str, torch.Tensor], + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + *, + batch: int, + head_dim: int, + eps: float = EPS, +) -> torch.Tensor: + w = layer_weights + hidden_size = w["wq"].shape[0] + kv_hidden = w["wk"].shape[1] + num_heads = hidden_size // head_dim + num_kv_heads = kv_hidden // head_dim + q_per_kv = num_heads // num_kv_heads + max_seq = rope_cos.shape[0] + scale = 1.0 / math.sqrt(head_dim) + + wq_f = w["wq"].float() + wk_f = w["wk"].float() + wv_f = w["wv"].float() + wo_f = w["wo"].float() + w_gate_f = w["w_gate"].float() + w_up_f = w["w_up"].float() + w_down_f = w["w_down"].float() + rms_w = w["input_rms_weight"].float() + post_rms_w = w["post_rms_weight"].float() + q_norm_w = w["q_norm_weight"].float() + k_norm_w = w["k_norm_weight"].float() + + out = torch.zeros(batch, hidden_size, dtype=torch.bfloat16) + + for b in range(batch): + ctx_len = int(seq_lens[b].item()) + pos = ctx_len - 1 + x = hidden_states[b:b + 1, :].float() + variance = x.square().mean(dim=-1, keepdim=True) + eps + normed = (x / torch.sqrt(variance) * rms_w).to(torch.bfloat16).float() + q_proj = normed @ wq_f + k_proj = normed @ wk_f + v_proj = normed @ wv_f + + q_heads = q_proj.view(num_heads, head_dim) + q_heads = q_heads * torch.rsqrt(q_heads.pow(2).mean(-1, keepdim=True) + eps) * q_norm_w + k_heads = k_proj.view(num_kv_heads, head_dim) + k_heads = k_heads * torch.rsqrt(k_heads.pow(2).mean(-1, keepdim=True) + eps) * k_norm_w + + cos = rope_cos[pos:pos + 1, :] + sin = rope_sin[pos:pos + 1, :] + k_rot = 
_apply_rope(k_heads, cos, sin).to(torch.bfloat16) + q_rot = _apply_rope(q_heads, cos, sin).to(torch.bfloat16) + v_bf16 = v_proj.view(num_kv_heads, head_dim).to(torch.bfloat16) + + for ki in range(num_kv_heads): + cr = b * num_kv_heads * max_seq + ki * max_seq + pos + k_cache[cr, :] = k_rot[ki] + v_cache[cr, :] = v_bf16[ki] + + attn_result = torch.zeros(1, hidden_size, dtype=torch.float32) + for ki in range(num_kv_heads): + row0 = b * num_kv_heads * max_seq + ki * max_seq + k_ctx = k_cache[row0:row0 + ctx_len, :] + v_ctx = v_cache[row0:row0 + ctx_len, :] + + for qi in range(q_per_kv): + hi = ki * q_per_kv + qi + q_head = q_rot[hi:hi + 1, :] + scores = (q_head.float() @ k_ctx.float().T) * scale + attn_w = torch.softmax(scores, dim=-1) + attn_result[:, hi * head_dim:(hi + 1) * head_dim] = attn_w @ v_ctx.float() + + attn_bf16 = attn_result.to(torch.bfloat16).float() + hs = hidden_states[b:b + 1, :].float() + resid1 = torch.matmul(attn_bf16, wo_f) + hs + variance = resid1.pow(2).mean(-1, keepdim=True) + post_normed = (resid1 * torch.rsqrt(variance + eps) * post_rms_w).bfloat16().float() + gate = torch.matmul(post_normed, w_gate_f) + up = torch.matmul(post_normed, w_up_f) + mlp = (gate * torch.sigmoid(gate) * up).bfloat16().float() + down = torch.matmul(mlp, w_down_f) + out[b:b + 1, :] = (down + resid1).bfloat16() + + return out + + +class _NPURunner: + def __init__(self, dec: Any, prefill: Any, *, batch: int, max_seq: int, platform: str, device_id: int): + from pypto import ir + from pypto.backend import BackendType + from pypto.runtime import execute_compiled + + backend_map = { + "a2a3": BackendType.Ascend910B, + "a2a3sim": BackendType.Ascend910B, + "a5": BackendType.Ascend950, + "a5sim": BackendType.Ascend950, + } + self._execute = execute_compiled + self._platform = platform + self._device_id = device_id + backend = backend_map[platform] + self._prefill_spec_order = [ + "hidden_states", "seq_lens", "input_rms_weight", + "wq", "wk", "wv", "q_norm_weight", "k_norm_weight", + "rope_cos", "rope_sin", "k_cache", "v_cache", + "wo", "post_rms_weight", "w_gate", "w_up", "w_down", "out", + ] + self._decode_spec_order = [ + "hidden_states", "input_rms_weight", + "wq", "wk", "wv", "q_norm_weight", "k_norm_weight", + "seq_lens", "rope_cos", "rope_sin", "k_cache", "v_cache", + "wo", "post_rms_weight", "w_gate", "w_up", "w_down", "out", + ] + + compiled_prefill = ir.compile(prefill.build_qwen3_14b_prefill_program(batch=batch, max_seq=max_seq), backend_type=backend) + self._prefill_dir = compiled_prefill.output_dir + compiled_decode = ir.compile(dec.build_qwen3_decode_program(batch=batch, max_seq=max_seq), backend_type=backend) + self._decode_dir = compiled_decode.output_dir + + def run_prefill( + self, + hidden_states: torch.Tensor, + seq_lens: torch.Tensor, + layer_weights: dict[str, torch.Tensor], + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + *, + batch: int, + max_seq: int, + ) -> torch.Tensor: + hidden = layer_weights["wq"].shape[0] + out = torch.zeros(batch, max_seq, hidden, dtype=torch.bfloat16) + tensors = { + "hidden_states": hidden_states, + "seq_lens": seq_lens, + "rope_cos": rope_cos, + "rope_sin": rope_sin, + "k_cache": k_cache, + "v_cache": v_cache, + "out": out, + **layer_weights, + } + ordered = [tensors[name] for name in self._prefill_spec_order] + self._execute(self._prefill_dir, ordered, platform=self._platform, device_id=self._device_id) + return tensors["out"] + + def run_decode( + self, + hidden_states: torch.Tensor, + 
seq_lens: torch.Tensor, + layer_weights: dict[str, torch.Tensor], + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + *, + batch: int, + ) -> torch.Tensor: + hidden = layer_weights["wq"].shape[0] + out = torch.zeros(batch, hidden, dtype=torch.bfloat16) + tensors = { + "hidden_states": hidden_states, + "seq_lens": seq_lens, + "rope_cos": rope_cos, + "rope_sin": rope_sin, + "k_cache": k_cache, + "v_cache": v_cache, + "out": out, + **layer_weights, + } + ordered = [tensors[name] for name in self._decode_spec_order] + self._execute(self._decode_dir, ordered, platform=self._platform, device_id=self._device_id) + return tensors["out"] + + +def _hf_e2e_forward( + inputs: Mapping[str, torch.Tensor], + hf_state: Mapping[str, torch.Tensor], + *, + model_path: str, + num_layers: int, + max_new_tokens: int, + hf_dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + from transformers import AutoConfig + from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + + cfg = AutoConfig.from_pretrained(model_path) + cfg._attn_implementation = "eager" + cfg.num_hidden_layers = num_layers + + model = Qwen3ForCausalLM(cfg).to(hf_dtype).eval() + sd = {k: v.to(hf_dtype) for k, v in _state_for_num_layers(hf_state, num_layers=num_layers).items()} + missing, unexpected = model.load_state_dict(sd, strict=False) + if unexpected: + raise RuntimeError(f"[qwen3.e2e] unexpected HF keys: {unexpected}") + missing = [key for key in missing if "rotary_emb.inv_freq" not in key] + if missing: + raise RuntimeError(f"[qwen3.e2e] missing HF keys: {missing}") + + input_ids = inputs["input_ids"].to(torch.long) + with torch.no_grad(): + out = model(input_ids=input_ids, use_cache=True) + logits = out.logits[:, -1, :].float() + next_token = logits.argmax(dim=-1) + generated: list[int] = [int(next_token.item())] + past_key_values = out.past_key_values + + for _ in range(1, max_new_tokens): + out = model(input_ids=next_token.view(1, 1), use_cache=True, past_key_values=past_key_values) + logits = out.logits[:, -1, :].float() + next_token = logits.argmax(dim=-1) + generated.append(int(next_token.item())) + past_key_values = out.past_key_values + + return { + "generated_ids": torch.tensor([generated], dtype=torch.int64), + "last_logits": logits, + } + + +def _run_e2e_target( + inputs: Mapping[str, torch.Tensor], + weights: Mapping[str, torch.Tensor], + *, + max_new_tokens: int, + num_layers: int, + max_seq: int, + cpu_only: bool, + platform: str, + device_id: int | str | None, +) -> dict[str, torch.Tensor]: + dec, prefill = _load_qwen3_modules() + + input_ids = inputs["input_ids"].to(torch.long) + batch = int(input_ids.shape[0]) + if batch != 1: + raise ValueError(f"qwen3_14b.e2e expects batch=1, got {batch}") + + prompt_len = int(input_ids.shape[1]) + if prompt_len + max_new_tokens - 1 > max_seq: + raise ValueError( + f"prompt_len + max_new_tokens - 1 = {prompt_len + max_new_tokens - 1} exceeds max_seq={max_seq}" + ) + + hidden = int(prefill.HIDDEN) + head_dim = int(prefill.HEAD_DIM) + num_kv_heads = int(prefill.NUM_KV_HEADS) + cache_rows = batch * num_kv_heads * max_seq + + layer_weights = [ + _extract_layer_weights(dict(weights), layer_idx=idx, hidden=hidden, head_dim=head_dim) + for idx in range(num_layers) + ] + embed_weight = weights["model.embed_tokens.weight"] + norm_weight = weights["model.norm.weight"].view(1, hidden) + lm_weight = weights["lm_head.weight"] + rope_cos, rope_sin = _build_rope(max_seq, head_dim) + + kv_caches = [ + { + "k_cache": torch.zeros(cache_rows, 
head_dim, dtype=torch.bfloat16), + "v_cache": torch.zeros(cache_rows, head_dim, dtype=torch.bfloat16), + } + for _ in range(num_layers) + ] + + resolved_device_id = _resolve_device_id(device_id) + npu_runner = None if cpu_only else _NPURunner( + dec, + prefill, + batch=batch, + max_seq=max_seq, + platform=platform, + device_id=resolved_device_id, + ) + + hidden_states_full = torch.zeros(batch, max_seq, hidden, dtype=torch.bfloat16) + hidden_states_full[:, :prompt_len, :] = _embed_tokens(input_ids, embed_weight) + seq_lens = torch.full((batch,), prompt_len, dtype=torch.int32) + + for layer_idx in range(num_layers): + if cpu_only: + hidden_states_full = _run_cpu_prefill( + hidden_states_full, + seq_lens, + layer_weights[layer_idx], + rope_cos, + rope_sin, + kv_caches[layer_idx]["k_cache"], + kv_caches[layer_idx]["v_cache"], + batch=batch, + max_seq=max_seq, + head_dim=head_dim, + ) + else: + hidden_states_full = npu_runner.run_prefill( + hidden_states_full, + seq_lens, + layer_weights[layer_idx], + rope_cos, + rope_sin, + kv_caches[layer_idx]["k_cache"], + kv_caches[layer_idx]["v_cache"], + batch=batch, + max_seq=max_seq, + ) + + last_hidden = hidden_states_full[:, prompt_len - 1, :].float() + logits = _lm_head_forward(_final_rmsnorm(last_hidden, norm_weight), lm_weight).float() + next_token = _greedy_sample(logits) + generated = [int(next_token.item())] + cur_seq_len = prompt_len + 1 + + for _ in range(1, max_new_tokens): + hidden_states = _embed_tokens(next_token.unsqueeze(0), embed_weight).squeeze(1) + decode_seq_lens = torch.full((batch,), cur_seq_len, dtype=torch.int32) + + for layer_idx in range(num_layers): + if cpu_only: + hidden_states = _run_cpu_decode( + hidden_states, + decode_seq_lens, + layer_weights[layer_idx], + rope_cos, + rope_sin, + kv_caches[layer_idx]["k_cache"], + kv_caches[layer_idx]["v_cache"], + batch=batch, + head_dim=head_dim, + ) + else: + hidden_states = npu_runner.run_decode( + hidden_states, + decode_seq_lens, + layer_weights[layer_idx], + rope_cos, + rope_sin, + kv_caches[layer_idx]["k_cache"], + kv_caches[layer_idx]["v_cache"], + batch=batch, + ) + + logits = _lm_head_forward(_final_rmsnorm(hidden_states.float(), norm_weight), lm_weight).float() + next_token = _greedy_sample(logits) + generated.append(int(next_token.item())) + cur_seq_len += 1 + + return { + "generated_ids": torch.tensor([generated], dtype=torch.int64), + "last_logits": logits, + } + + +@register_case("qwen3_14b.e2e") +def build( + hf_model_path: str = DEFAULT_MODEL_PATH, + prompt_ids: str = DEFAULT_PROMPT_IDS, + cpu_only: Any = True, + platform: str = "a2a3", + device_id: int | str | None = None, + num_layers: int = 2, + max_new_tokens: int = 4, + max_seq: int = 256, + atol: float = 5e-3, + rtol: float = 5e-3, + hf_dtype: str = "fp32", +) -> ComparisonCase: + cpu_only = _coerce_bool(cpu_only) + device_id = int(device_id) if device_id is not None else None + num_layers = int(num_layers) + max_new_tokens = int(max_new_tokens) + max_seq = int(max_seq) + atol = float(atol) + rtol = float(rtol) + hf_dtype_t = torch.float32 if hf_dtype == "fp32" else torch.bfloat16 + + prompt = _parse_prompt_ids(prompt_ids) + if len(prompt) + max_new_tokens - 1 > max_seq: + raise ValueError( + f"prompt length {len(prompt)} with max_new_tokens={max_new_tokens} exceeds max_seq={max_seq}" + ) + + prompt_tensor = torch.tensor(prompt, dtype=torch.int64).view(1, -1) + input_spec = InputSpec( + tensors={ + "input_ids": TensorSpec( + tuple(prompt_tensor.shape), + torch.int64, + sampler=lambda s, d, g, t=prompt_tensor: 
t.clone().to(d), + ), + } + ) + + reference = CallableReference( + name="hf.Qwen3ForCausalLM", + fn=lambda inp, st: _hf_e2e_forward( + inp, + st, + model_path=hf_model_path, + num_layers=num_layers, + max_new_tokens=max_new_tokens, + hf_dtype=hf_dtype_t, + ), + ) + target = CallableTarget( + name="pytorch.qwen3_14b_e2e" if cpu_only else f"pypto.qwen3_14b_e2e[{platform}]", + fn=lambda inp, w: _run_e2e_target( + inp, + w, + max_new_tokens=max_new_tokens, + num_layers=num_layers, + max_seq=max_seq, + cpu_only=cpu_only, + platform=platform, + device_id=device_id, + ), + ) + + return ComparisonCase( + name="qwen3_14b.e2e", + reference=reference, + target=target, + input_spec=input_spec, + weight_adapter=PassthroughAdapter(), + selectors=[ + OutputSelector( + name="generated_ids", + ref_key="generated_ids", + tgt_key="generated_ids", + # int64 fits exactly in float32 for typical token IDs (< 2^24). + cast_to=torch.float32, + ), + OutputSelector( + name="last_logits", + ref_key="last_logits", + tgt_key="last_logits", + cast_to=torch.float32, + ), + ], + tolerance=Tolerance(atol=atol, rtol=rtol, pass_rate_threshold=1.0), + hf_weights=hf_model_path, + ) diff --git a/llm/testing/hf_compare/cases/qwen3_14b_e2e_fused.py b/llm/testing/hf_compare/cases/qwen3_14b_e2e_fused.py new file mode 100644 index 00000000..1d13e570 --- /dev/null +++ b/llm/testing/hf_compare/cases/qwen3_14b_e2e_fused.py @@ -0,0 +1,296 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Comparison case: Qwen3-14B end-to-end generation via the production LLMEngine. + +This case drives the same code path as ``llm.examples.qwen3_14b_npu_generate``: +``LLMEngine`` + ``PyptoQwen14BExecutor``, which compiles the per-layer prefill +kernel and the fused multi-layer decode kernel and manages the paged KV cache. + +The HF reference side is unchanged: a stock ``Qwen3ForCausalLM`` runs greedy +decoding on the same prompt with the same number of layers (40, the full +model -- the kernel is compiled for exactly that depth). + +Run via:: + + python -m llm.testing.hf_compare run qwen3_14b.e2e_fused \\ + -k hf_model_path=/data/linyifan/models/Qwen3-14B \\ + -k platform=a2a3 \\ + -k max_new_tokens=16 \\ + -k max_seq=512 +""" +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import torch + +from ..base import ComparisonCase, InputSpec, OutputSelector, TensorSpec, Tolerance, register_case +from ..reference import CallableReference +from ..target import CallableTarget, _resolve_device_id +from ..weight_adapter import PassthroughAdapter + +# Reuse the HF reference forward + prompt parsing verbatim from the per-layer +# e2e case so behaviour stays in sync. 
+from .qwen3_14b_e2e import ( + DEFAULT_MODEL_PATH, + DEFAULT_PROMPT_IDS, + _hf_e2e_forward, + _parse_prompt_ids, +) + + +# Qwen3-14B is hardcoded into the fused decode kernel and the bundled weight +# stacking path (see llm/core/pypto_executor.py). Smaller depths require a +# kernel rebuild and are not supported by this case. +NUM_LAYERS = 40 + + +def _run_prefill_step( + executor: Any, + runtime_model: Any, + *, + request_id: str, + prompt_token_ids: list[int], + alloc: Any, + device: torch.device, +) -> torch.Tensor: + """Run one prefill step, return the final-token logits.""" + from llm.core.types import PrefillBatch + + prompt_len = len(prompt_token_ids) + token_tensor = torch.tensor(prompt_token_ids, dtype=torch.long, device=device).view(1, -1) + embeddings = executor.lookup_embeddings(runtime_model, token_tensor.view(-1)).view( + 1, prompt_len, -1 + ) + seq_lens = torch.tensor([prompt_len], dtype=torch.int32, device=device) + + result = executor.run_prefill( + runtime_model, + PrefillBatch( + request_ids=[request_id], + token_ids=token_tensor, + input_embeddings=embeddings, + seq_lens=seq_lens, + kv_allocations=[alloc], + ), + ) + return result.logits + + +def _run_decode_step( + executor: Any, + runtime_model: Any, + kv_cache_manager: Any, + *, + request_id: str, + next_token: int, + cur_seq_len: int, + alloc: Any, + device: torch.device, +) -> torch.Tensor: + """Run one decode step for a single batch=1 sequence, return logits.""" + from llm.core.types import DecodeBatch + + decode_token_tensor = torch.tensor([next_token], dtype=torch.long, device=device) + decode_embeddings = executor.lookup_embeddings(runtime_model, decode_token_tensor) + decode_seq_lens = torch.tensor([cur_seq_len], dtype=torch.int32, device=device) + block_table = kv_cache_manager.block_table_for_batch_padded([alloc]).to(device) + slot_mapping = kv_cache_manager.slot_mapping_for_batch([alloc]).to(device) + + result = executor.run_decode( + runtime_model, + DecodeBatch( + request_ids=[request_id], + token_ids=decode_token_tensor.unsqueeze(1), + hidden_states=decode_embeddings, + seq_lens=decode_seq_lens, + kv_allocations=[alloc], + block_table=block_table, + slot_mapping=slot_mapping, + ), + ) + return result.logits + + +def _run_e2e_fused_target( + inputs: Mapping[str, torch.Tensor], + weights: Mapping[str, torch.Tensor], # noqa: ARG001 -- engine loads weights itself + *, + model_dir: str, + max_new_tokens: int, + max_seq: int, + platform: str, + device_id: int | str | None, +) -> dict[str, torch.Tensor]: + """Drive `LLMEngine` + `PyptoQwen14BExecutor` for one greedy generation. + + Mirrors `LLMEngine.generate_batch` (llm/core/engine.py) but bypasses the + tokenizer to feed pre-tokenized prompt IDs supplied by the HF-compare + fixture, and captures the final decode logits for downstream comparison. 
+ """ + from llm.core import LLMEngine, RuntimeConfig + from llm.core.kv_cache import KvCacheManager + from llm.core.pypto_executor import PyptoQwen14BExecutor + + resolved_device_id = _resolve_device_id(device_id) + + kv_cache_manager = KvCacheManager() + executor = PyptoQwen14BExecutor( + kv_cache_manager, + platform=platform, + device_id=resolved_device_id, + ) + engine = LLMEngine(kv_cache_manager=kv_cache_manager, executor=executor) + + model_id = "qwen3_14b" + engine.init_model( + model_id=model_id, + model_dir=str(model_dir), + model_format="huggingface", + runtime_config=RuntimeConfig( + page_size=256, + max_batch_size=16, + max_seq_len=max_seq, + device="cpu", + kv_dtype="bfloat16", + weight_dtype="float32", + ), + ) + + record = engine._models[model_id] # noqa: SLF001 -- intentional engine introspection + runtime_model = record.runtime_model + device = runtime_model.runtime.device + + input_ids = inputs["input_ids"].to(torch.long) + if input_ids.shape[0] != 1: + raise ValueError(f"qwen3_14b.e2e_fused expects batch=1, got {input_ids.shape[0]}") + prompt_token_ids: list[int] = input_ids[0].tolist() + prompt_len = len(prompt_token_ids) + if prompt_len + max_new_tokens - 1 > max_seq: + raise ValueError( + f"prompt_len + max_new_tokens - 1 = {prompt_len + max_new_tokens - 1} " + f"exceeds max_seq={max_seq}" + ) + + request_id = "req-0" + alloc = kv_cache_manager.allocate_for_prompt(model_id, request_id, prompt_len) + generated: list[int] = [] + + try: + last_logits = _run_prefill_step( + executor, runtime_model, + request_id=request_id, prompt_token_ids=prompt_token_ids, + alloc=alloc, device=device, + ) + next_token = int(torch.argmax(last_logits[0]).item()) + generated.append(next_token) + + cur_seq_len = prompt_len + for _ in range(1, max_new_tokens): + kv_cache_manager.ensure_one_more_slot(alloc) + cur_seq_len += 1 + last_logits = _run_decode_step( + executor, runtime_model, kv_cache_manager, + request_id=request_id, next_token=next_token, + cur_seq_len=cur_seq_len, alloc=alloc, device=device, + ) + next_token = int(torch.argmax(last_logits[0]).item()) + generated.append(next_token) + finally: + kv_cache_manager.free(alloc) + + return { + "generated_ids": torch.tensor([generated], dtype=torch.int64), + "last_logits": last_logits.float(), + } + + +@register_case("qwen3_14b.e2e_fused") +def build( + hf_model_path: str = DEFAULT_MODEL_PATH, + prompt_ids: str = DEFAULT_PROMPT_IDS, + platform: str = "a2a3", + device_id: int | str | None = None, + max_new_tokens: int = 16, + max_seq: int = 512, + atol: float = 5e-3, + rtol: float = 5e-3, + hf_dtype: str = "fp32", +) -> ComparisonCase: + device_id = int(device_id) if device_id is not None else None + max_new_tokens = int(max_new_tokens) + max_seq = int(max_seq) + atol = float(atol) + rtol = float(rtol) + hf_dtype_t = torch.float32 if hf_dtype == "fp32" else torch.bfloat16 + + prompt = _parse_prompt_ids(prompt_ids) + if len(prompt) + max_new_tokens - 1 > max_seq: + raise ValueError( + f"prompt length {len(prompt)} with max_new_tokens={max_new_tokens} " + f"exceeds max_seq={max_seq}" + ) + + prompt_tensor = torch.tensor(prompt, dtype=torch.int64).view(1, -1) + input_spec = InputSpec( + tensors={ + "input_ids": TensorSpec( + tuple(prompt_tensor.shape), + torch.int64, + sampler=lambda s, d, g, t=prompt_tensor: t.clone().to(d), + ), + } + ) + + reference = CallableReference( + name="hf.Qwen3ForCausalLM", + fn=lambda inp, st: _hf_e2e_forward( + inp, st, + model_path=hf_model_path, + num_layers=NUM_LAYERS, + max_new_tokens=max_new_tokens, + 
hf_dtype=hf_dtype_t, + ), + ) + target = CallableTarget( + name=f"pypto.qwen3_14b_e2e_fused[{platform}]", + fn=lambda inp, w: _run_e2e_fused_target( + inp, w, + model_dir=hf_model_path, + max_new_tokens=max_new_tokens, + max_seq=max_seq, + platform=platform, + device_id=device_id, + ), + ) + + return ComparisonCase( + name="qwen3_14b.e2e_fused", + reference=reference, + target=target, + input_spec=input_spec, + weight_adapter=PassthroughAdapter(), + selectors=[ + OutputSelector( + name="generated_ids", + ref_key="generated_ids", + tgt_key="generated_ids", + cast_to=torch.float32, + ), + OutputSelector( + name="last_logits", + ref_key="last_logits", + tgt_key="last_logits", + cast_to=torch.float32, + ), + ], + tolerance=Tolerance(atol=atol, rtol=rtol, pass_rate_threshold=1.0), + hf_weights=hf_model_path, + ) diff --git a/llm/testing/hf_compare/cases/qwen3_14b_prefill.py b/llm/testing/hf_compare/cases/qwen3_14b_prefill.py new file mode 100644 index 00000000..2bf342f3 --- /dev/null +++ b/llm/testing/hf_compare/cases/qwen3_14b_prefill.py @@ -0,0 +1,274 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Comparison case: Qwen3-14B prefill (single layer) vs Hugging Face Qwen3DecoderLayer. + +Port of ``examples/models/qwen3/14b/compare_prefill_with_hf.py``. 
Run via: + + python -m llm.testing.hf_compare run qwen3_14b.prefill \\ + -k hf_model_path=/data/linyifan/models/Qwen3-14B \\ + -k seq_len=128 -k platform=a2a3 -k batch=16 + # CPU-only: + python -m llm.testing.hf_compare run qwen3_14b.prefill -k cpu_only=true -k seq_len=128 +""" +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import torch + +from ..base import ComparisonCase, InputSpec, OutputSelector, TensorSpec, Tolerance, register_case +from ..input_sampler import build_rope_tables, int_fill, uniform +from ..paged_kv import ( + compute_prefill_slot_mapping, + gather_prefill_kv, + init_block_table_identity, +) +from ..reference import CallableReference +from ..target import CallableTarget, PyPTOKernelTarget +from ._qwen3_14b_common import ( + DEFAULT_MODEL_PATH, + HEAD_DIM, + HIDDEN, + NUM_KV_HEADS, + PREFILL_SPEC_ORDER as SPEC_ORDER, + coerce_bool, + load_kernel_module, + single_layer_adapter, +) + + +# --------------------------------------------------------------------------- +# HF reference forward (full-sequence prefill) +# --------------------------------------------------------------------------- +def _hf_prefill_forward( + inputs: Mapping[str, torch.Tensor], + hf_state: Mapping[str, torch.Tensor], + *, + model_path: str, + hf_dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + """Run Qwen3DecoderLayer for layer 0 on a full prefill sequence.""" + from transformers import AutoConfig + from transformers.cache_utils import DynamicCache + from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer + + cfg = AutoConfig.from_pretrained(model_path) + cfg._attn_implementation = "eager" + layer = Qwen3DecoderLayer(cfg, layer_idx=0).to(hf_dtype).eval() + + prefix = "model.layers.0." + sd = {k[len(prefix):]: v.to(hf_dtype) for k, v in hf_state.items() if k.startswith(prefix)} + missing, unexpected = layer.load_state_dict(sd, strict=False) + if missing: + raise RuntimeError(f"[qwen3.prefill] missing HF keys: {missing}") + if unexpected: + raise RuntimeError(f"[qwen3.prefill] unexpected HF keys: {unexpected}") + + batch = inputs["hidden_states"].shape[0] + seq_len = int(inputs["seq_lens"][0].item()) + + cos = inputs["rope_cos"][:seq_len].to(hf_dtype).unsqueeze(0).expand(batch, -1, -1) + sin = inputs["rope_sin"][:seq_len].to(hf_dtype).unsqueeze(0).expand(batch, -1, -1) + hs_in = inputs["hidden_states"][:, :seq_len, :].to(hf_dtype) + pos_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(batch, -1) + + causal_mask = torch.full((seq_len, seq_len), float("-inf"), dtype=hf_dtype) + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch, 1, -1, -1) + + cache = DynamicCache() + with torch.no_grad(): + out = layer( + hidden_states=hs_in, + attention_mask=causal_mask, + position_ids=pos_ids, + past_key_values=cache, + position_embeddings=(cos, sin), + ) + if isinstance(out, tuple): + out = out[0] + return { + "layer_out": out, # [B, S, H] + "k_cache": cache.layers[0].keys, # [B, nkv, S, head_dim] + "v_cache": cache.layers[0].values, # [B, nkv, S, head_dim] + } + + +def _select_npu_layer_out( + flat: torch.Tensor, *, batch: int, max_seq: int, seq_len: int, +) -> torch.Tensor: + return flat.view(batch, max_seq, HIDDEN)[:, :seq_len, :].clone() + + +# --------------------------------------------------------------------------- +# Case factory +# --------------------------------------------------------------------------- +@register_case("qwen3_14b.prefill") +def build( 
+ hf_model_path: str = DEFAULT_MODEL_PATH, + cpu_only: Any = False, + platform: str = "a2a3", + device_id: int | str | None = None, + batch: int = 1, + seed: int = 0, + seq_len: int = 128, + max_seq: int | None = None, + atol: float = 5e-3, + rtol: float = 5e-3, + hf_dtype: str = "fp32", +) -> ComparisonCase: + cpu_only = coerce_bool(cpu_only) + device_id = int(device_id) if device_id is not None else None + batch = int(batch) + seed = int(seed) + seq_len = int(seq_len) + max_seq_arg = int(max_seq) if max_seq is not None else None + atol = float(atol) + rtol = float(rtol) + hf_dtype_t = torch.float32 if hf_dtype == "fp32" else torch.bfloat16 + + prefill = load_kernel_module( + "models/qwen3/14b/qwen3_14b_prefill.py", + "_qwen3_14b_prefill_kernel", + ) + + if batch <= 0: + raise ValueError(f"batch must be positive, got {batch}") + max_seq_eff = max_seq_arg or prefill.MAX_SEQ + if seq_len > max_seq_eff: + raise ValueError(f"seq_len={seq_len} > max_seq={max_seq_eff}") + max_blocks_per_seq = (max_seq_eff + prefill.BLOCK_SIZE - 1) // prefill.BLOCK_SIZE + num_blocks = batch * max_blocks_per_seq + cache_rows = num_blocks * NUM_KV_HEADS * prefill.BLOCK_SIZE + + # Capture block_size for the KV selector closure. + _block_size = int(prefill.BLOCK_SIZE) + + # ---- inputs ------------------------------------------------------------ + cos_tab, sin_tab = build_rope_tables(max_seq_eff, HEAD_DIM, base=1_000_000.0) + input_spec = InputSpec( + seed=seed, + tensors={ + "hidden_states": TensorSpec( + (batch, max_seq_eff, HIDDEN), torch.bfloat16, sampler=uniform() + ), + "seq_lens": TensorSpec((batch,), torch.int32, sampler=int_fill(seq_len)), + "block_table": TensorSpec((batch * max_blocks_per_seq,), torch.int32), + "slot_mapping": TensorSpec((batch * max_seq_eff,), torch.int32), + "k_cache": TensorSpec((cache_rows, HEAD_DIM), torch.bfloat16), + "v_cache": TensorSpec((cache_rows, HEAD_DIM), torch.bfloat16), + "out": TensorSpec((batch, max_seq_eff, HIDDEN), torch.bfloat16), + "rope_cos": TensorSpec( + tuple(cos_tab.shape), torch.float32, + sampler=lambda s, d, g: cos_tab.clone().to(d), + ), + "rope_sin": TensorSpec( + tuple(sin_tab.shape), torch.float32, + sampler=lambda s, d, g: sin_tab.clone().to(d), + ), + }, + ) + + # ---- weight adapter (identical layout to decode case) ------------------ + adapter = single_layer_adapter(layer_idx=0) + + # ---- reference --------------------------------------------------------- + reference = CallableReference( + name="hf.Qwen3DecoderLayer[prefill]", + fn=lambda inp, st: _hf_prefill_forward( + inp, st, model_path=hf_model_path, hf_dtype=hf_dtype_t, + ), + ) + + # ---- target ------------------------------------------------------------ + def _post_run(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + kv_kwargs = dict( + batch=batch, max_seq=max_seq_eff, seq_len=seq_len, + num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM, block_size=_block_size, + ) + return { + "layer_out": _select_npu_layer_out( + tensors["out"], batch=batch, max_seq=max_seq_eff, seq_len=seq_len, + ), + "k_cache_written": gather_prefill_kv( + tensors["k_cache"], tensors["slot_mapping"], **kv_kwargs, + ), + "v_cache_written": gather_prefill_kv( + tensors["v_cache"], tensors["slot_mapping"], **kv_kwargs, + ), + } + + if cpu_only: + def _golden(inp: Mapping[str, torch.Tensor], _w: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + ref_inputs = {**_w, **inp} + prefill.golden_qwen3_14b_prefill(ref_inputs) + return { + "layer_out": _select_npu_layer_out( + ref_inputs["out"], batch=batch, 
max_seq=max_seq_eff, seq_len=seq_len, + ), + } + target = CallableTarget(name="pytorch.golden_qwen3_14b_prefill", fn=_golden) + # Golden does not write KV back; only layer_out is comparable. + selectors = [ + OutputSelector(name="layer_out", ref_key="layer_out", tgt_key="layer_out"), + ] + else: + target = PyPTOKernelTarget( + name=f"pypto.qwen3_14b_prefill[{platform}]", + build_program=lambda: prefill.build_qwen3_14b_prefill_program( + batch=batch, max_seq=max_seq_eff, + ), + spec_order=SPEC_ORDER, + platform=platform, + device_id=device_id, + post_run=_post_run, + ) + selectors = [ + OutputSelector(name="layer_out", ref_key="layer_out", tgt_key="layer_out"), + OutputSelector(name="k_cache", ref_key="k_cache", tgt_key="k_cache_written"), + OutputSelector(name="v_cache", ref_key="v_cache", tgt_key="v_cache_written"), + ] + + return ComparisonCase( + name="qwen3_14b.prefill", + reference=reference, + target=target, + input_spec=input_spec, + weight_adapter=adapter, + selectors=selectors, + tolerance=Tolerance(atol=atol, rtol=rtol), + hf_weights=hf_model_path, + on_inputs=lambda t: _init_paged_attention_inputs( + t, batch=batch, max_seq=max_seq_eff, block_size=prefill.BLOCK_SIZE, max_blocks_per_seq=max_blocks_per_seq, + ), + ) + + +# --------------------------------------------------------------------------- +# Paged-attention metadata initializers (block_table / slot_mapping) +# --------------------------------------------------------------------------- + + +def _init_paged_attention_inputs( + tensors: dict[str, torch.Tensor], + *, + batch: int, + max_seq: int, + block_size: int, + max_blocks_per_seq: int, +) -> None: + init_block_table_identity( + tensors["block_table"], batch=batch, max_blocks_per_seq=max_blocks_per_seq, + ) + compute_prefill_slot_mapping( + tensors["slot_mapping"], + batch=batch, max_seq=max_seq, + block_size=block_size, max_blocks_per_seq=max_blocks_per_seq, + ) diff --git a/llm/testing/hf_compare/comparator.py b/llm/testing/hf_compare/comparator.py new file mode 100644 index 00000000..8610762f --- /dev/null +++ b/llm/testing/hf_compare/comparator.py @@ -0,0 +1,87 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- +"""Tensor comparison metrics and the pass/fail decision.""" +from __future__ import annotations + +import torch + +from .base import OutputSelector, SelectorResult, Tolerance + + +class Comparator: + """Computes metrics and pass/fail for a single (ref, tgt) tensor pair.""" + + def __init__(self, tolerance: Tolerance): + self.tolerance = tolerance + + def compare( + self, + ref: torch.Tensor, + tgt: torch.Tensor, + selector: OutputSelector, + ) -> SelectorResult: + note = "" + if ref.shape != tgt.shape: + return SelectorResult( + selector=selector.name, + metrics={}, + passed=False, + note=f"shape mismatch ref={tuple(ref.shape)} tgt={tuple(tgt.shape)}", + ) + + ref_f = ref.detach().to(selector.cast_to).reshape(-1) + tgt_f = tgt.detach().to(selector.cast_to).reshape(-1) + diff = (ref_f - tgt_f).abs() + + metrics: dict[str, float] = {} + requested = set(self.tolerance.metrics) + + if "max_abs" in requested: + metrics["max_abs"] = float(diff.max().item()) if diff.numel() else 0.0 + if "mean_abs" in requested: + metrics["mean_abs"] = float(diff.mean().item()) if diff.numel() else 0.0 + if "max_rel" in requested: + denom = ref_f.abs().clamp(min=1e-6) + metrics["max_rel"] = float((diff / denom).max().item()) if diff.numel() else 0.0 + if "cosine" in requested: + metrics["cosine"] = _cosine(ref_f, tgt_f) + + threshold = self.tolerance.atol + self.tolerance.rtol * ref_f.abs() + within = diff <= threshold + pass_rate = float(within.float().mean().item()) if diff.numel() else 1.0 + if "pass_rate" in requested: + metrics["pass_rate"] = pass_rate + + passed = pass_rate >= self.tolerance.pass_rate_threshold + + offenders: list[tuple[int, float, float]] = [] + if not passed and self.tolerance.worst_offenders > 0 and diff.numel() > 0: + k = min(self.tolerance.worst_offenders, diff.numel()) + top = torch.topk(diff, k) + for idx in top.indices.tolist(): + offenders.append((int(idx), float(ref_f[idx].item()), float(tgt_f[idx].item()))) + + return SelectorResult( + selector=selector.name, + metrics=metrics, + passed=passed, + worst_offenders=offenders, + note=note, + ) + + +def _cosine(a: torch.Tensor, b: torch.Tensor) -> float: + if a.numel() == 0: + return 1.0 + na = a.norm() + nb = b.norm() + if na.item() == 0.0 or nb.item() == 0.0: + # Degenerate: define cosine as 1.0 iff both zero, else 0.0. + return 1.0 if (na.item() == 0.0 and nb.item() == 0.0) else 0.0 + return float((a @ b / (na * nb)).item()) diff --git a/llm/testing/hf_compare/input_sampler.py b/llm/testing/hf_compare/input_sampler.py new file mode 100644 index 00000000..e080edb4 --- /dev/null +++ b/llm/testing/hf_compare/input_sampler.py @@ -0,0 +1,57 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- +"""Helpers for synthesizing inputs (activations, KV history, RoPE tables).""" +from __future__ import annotations + +from collections.abc import Callable + +import torch + + +def uniform(low: float = -0.5, high: float = 0.5) -> Callable[[tuple[int, ...], torch.dtype, torch.Generator], torch.Tensor]: + """Return a sampler producing uniform ``[low, high)`` values.""" + scale = high - low + + def _fn(shape: tuple[int, ...], dtype: torch.dtype, g: torch.Generator) -> torch.Tensor: + t = torch.rand(shape, generator=g) * scale + low + return t.to(dtype) + return _fn + + +def normal(mean: float = 0.0, std: float = 1.0) -> Callable[[tuple[int, ...], torch.dtype, torch.Generator], torch.Tensor]: + def _fn(shape: tuple[int, ...], dtype: torch.dtype, g: torch.Generator) -> torch.Tensor: + t = torch.randn(shape, generator=g) * std + mean + return t.to(dtype) + return _fn + + +def constant(value: float) -> Callable[[tuple[int, ...], torch.dtype, torch.Generator], torch.Tensor]: + def _fn(shape: tuple[int, ...], dtype: torch.dtype, g: torch.Generator) -> torch.Tensor: # noqa: ARG001 + return torch.full(shape, value, dtype=dtype) + return _fn + + +def int_fill(value: int) -> Callable[[tuple[int, ...], torch.dtype, torch.Generator], torch.Tensor]: + def _fn(shape: tuple[int, ...], dtype: torch.dtype, g: torch.Generator) -> torch.Tensor: # noqa: ARG001 + return torch.full(shape, value, dtype=dtype) + return _fn + + +def build_rope_tables( + max_seq: int, + head_dim: int, + base: float = 1_000_000.0, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build cos/sin tables compatible with HF rotate_half form.""" + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + pos = torch.arange(max_seq, dtype=torch.float32) + freqs = torch.outer(pos, inv_freq) + emb = torch.cat([freqs, freqs], dim=-1) + return emb.cos().to(dtype), emb.sin().to(dtype) diff --git a/llm/testing/hf_compare/paged_kv.py b/llm/testing/hf_compare/paged_kv.py new file mode 100644 index 00000000..37a40ab3 --- /dev/null +++ b/llm/testing/hf_compare/paged_kv.py @@ -0,0 +1,172 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Paged-attention KV-cache helpers shared by hf_compare cases. + +The kernels write KV into a flat row-major paged layout:: + + cache_row(pbid, kvh, offset) = (pbid * num_kv_heads + kvh) * block_size + offset + +For fused multi-layer kernels, layers are concatenated along axis 0; pass +``layer_offset_rows = layer_idx * layer_cache_rows`` to address a specific layer. + +This module provides: + +* ``init_block_table_identity`` / ``compute_*_slot_mapping`` -- build paged + metadata that case fixtures hand to kernels. 
+* ``select_kv_at_decode_slot`` -- gather the [batch, num_kv_heads, head_dim] + row written by a decode step. +* ``gather_prefill_kv`` -- assemble the [batch, num_kv_heads, seq_len, head_dim] + block written by prefill across all positions. +* ``paged_to_dense_history`` -- stitch paged cache rows into the dense + [batch, num_kv_heads, ctx_len, head_dim] tensors HF DynamicCache expects. +""" +from __future__ import annotations + +import torch + + +# --------------------------------------------------------------------------- +# Paged metadata initialisers (block_table / slot_mapping) +# --------------------------------------------------------------------------- +def init_block_table_identity( + block_table: torch.Tensor, *, batch: int, max_blocks_per_seq: int +) -> None: + """Fill block_table with logical->physical identity mapping (block i -> i).""" + block_table[:] = torch.arange(batch * max_blocks_per_seq, dtype=torch.int32) + + +def compute_decode_slot_mapping( + slot_mapping: torch.Tensor, + seq_lens: torch.Tensor, + *, + batch: int, + block_size: int, + max_blocks_per_seq: int, +) -> None: + """Set slot_mapping[b] for the current decode position (pos = seq_lens[b] - 1).""" + for b in range(batch): + pos = int(seq_lens[b].item()) - 1 + logical_block = pos // block_size + page_off = pos - logical_block * block_size + phys_block = b * max_blocks_per_seq + logical_block + slot_mapping[b] = phys_block * block_size + page_off + + +def compute_prefill_slot_mapping( + slot_mapping: torch.Tensor, + *, + batch: int, + max_seq: int, + block_size: int, + max_blocks_per_seq: int, +) -> None: + """Set slot_mapping[b*max_seq + pos] for every (b, pos) in [0, max_seq).""" + for b in range(batch): + base_block = b * max_blocks_per_seq + for pos in range(max_seq): + logical_block = pos // block_size + page_off = pos - logical_block * block_size + phys_block = base_block + logical_block + slot_mapping[b * max_seq + pos] = phys_block * block_size + page_off + + +# --------------------------------------------------------------------------- +# Cache-readback helpers (gather kernel-written KV for comparison) +# --------------------------------------------------------------------------- +def select_kv_at_decode_slot( + cache_flat: torch.Tensor, + slot_mapping: torch.Tensor, + *, + num_kv_heads: int, + head_dim: int, + block_size: int, + layer_offset_rows: int = 0, +) -> torch.Tensor: + """Read the row a decode step wrote, returning [batch, num_kv_heads, head_dim]. + + For a fused multi-layer cache pass ``layer_offset_rows = layer_idx * layer_cache_rows``. 
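+
+    Worked example (illustrative numbers only): with ``num_kv_heads=8`` and
+    ``block_size=128``, slot ``300`` splits into ``slot_block=2`` and
+    ``slot_off=44``, so head ``kvh=3`` is read from flat row
+    ``2 * 8 * 128 + 3 * 128 + 44 = 2476`` (offset further by
+    ``layer_offset_rows`` for fused multi-layer caches).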
+ """ + batch = int(slot_mapping.shape[0]) + out = torch.empty( + (batch, num_kv_heads, head_dim), + dtype=cache_flat.dtype, device=cache_flat.device, + ) + for b in range(batch): + slot = int(slot_mapping[b].item()) + slot_block = slot // block_size + slot_off = slot - slot_block * block_size + ofs = layer_offset_rows + slot_block * num_kv_heads * block_size + slot_off + for kvh in range(num_kv_heads): + out[b, kvh, :] = cache_flat[ofs + kvh * block_size, :] + return out.clone() + + +def gather_prefill_kv( + cache_flat: torch.Tensor, + slot_mapping: torch.Tensor, + *, + batch: int, + max_seq: int, + seq_len: int, + num_kv_heads: int, + head_dim: int, + block_size: int, +) -> torch.Tensor: + """Gather all prefill-written rows into a [B, num_kv_heads, seq_len, head_dim] tensor.""" + out = torch.empty( + (batch, num_kv_heads, seq_len, head_dim), + dtype=cache_flat.dtype, device=cache_flat.device, + ) + for b in range(batch): + for pos in range(seq_len): + slot = int(slot_mapping[b * max_seq + pos].item()) + slot_block = slot // block_size + slot_off = slot - slot_block * block_size + base = slot_block * num_kv_heads * block_size + slot_off + for kvh in range(num_kv_heads): + out[b, kvh, pos, :] = cache_flat[base + kvh * block_size, :] + return out.clone() + + +def paged_to_dense_history( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + block_table: torch.Tensor, + *, + batch: int, + ctx_len: int, + num_kv_heads: int, + head_dim: int, + block_size: int, + max_blocks_per_seq: int, + hf_dtype: torch.dtype, + layer_offset_rows: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + """Stitch the paged KV cache into dense ``[B, num_kv_heads, ctx_len, head_dim]``. + + Walks ``block_table`` to assemble per-sequence histories, then truncates + each to ``ctx_len``. Used by HF references that seed ``DynamicCache``. + For a fused multi-layer cache pass ``layer_offset_rows`` to point at the + layer of interest. + """ + ctx_blocks = (ctx_len + block_size - 1) // block_size + k_hist = torch.empty((batch, num_kv_heads, ctx_len, head_dim), dtype=hf_dtype) + v_hist = torch.empty((batch, num_kv_heads, ctx_len, head_dim), dtype=hf_dtype) + for b in range(batch): + for kvh in range(num_kv_heads): + k_blocks: list[torch.Tensor] = [] + v_blocks: list[torch.Tensor] = [] + for sb in range(ctx_blocks): + pbid = int(block_table[b * max_blocks_per_seq + sb].item()) + row0 = layer_offset_rows + (pbid * num_kv_heads + kvh) * block_size + k_blocks.append(k_cache[row0:row0 + block_size, :]) + v_blocks.append(v_cache[row0:row0 + block_size, :]) + k_hist[b, kvh, :, :] = torch.cat(k_blocks, dim=0)[:ctx_len, :].to(hf_dtype) + v_hist[b, kvh, :, :] = torch.cat(v_blocks, dim=0)[:ctx_len, :].to(hf_dtype) + return k_hist, v_hist diff --git a/llm/testing/hf_compare/reference.py b/llm/testing/hf_compare/reference.py new file mode 100644 index 00000000..0d00c819 --- /dev/null +++ b/llm/testing/hf_compare/reference.py @@ -0,0 +1,103 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Reference-side wrappers. + +Two flavors are provided: + + - ``CallableReference``: wrap any function ``(inputs, hf_state) -> dict``. + Zero dependencies beyond torch. Suitable for custom golden refs. + + - ``HFModuleReference``: wrap a Hugging Face ``nn.Module`` (or factory for + one), handle ``.load_state_dict`` with prefix stripping, run forward in + a chosen dtype, and return selected outputs. Requires ``transformers`` + at call time (imported lazily). +""" +from __future__ import annotations + +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field + +import torch + +from .base import ReferenceModel + + +# --------------------------------------------------------------------------- +# Callable reference (the most flexible option) +# --------------------------------------------------------------------------- +@dataclass +class CallableReference(ReferenceModel): + """Golden reference defined by a plain Python function. + + The forward function receives the InputSpec-materialized tensors plus the + raw HF state dict, and returns a dict of named output tensors. + """ + + name: str + fn: Callable[[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]], dict[str, torch.Tensor]] + _state: Mapping[str, torch.Tensor] = field(default_factory=dict, init=False, repr=False) + + def prepare(self, hf_state: Mapping[str, torch.Tensor]) -> None: + self._state = hf_state + + def forward(self, inputs: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return self.fn(inputs, self._state) + + +# --------------------------------------------------------------------------- +# Hugging Face module reference +# --------------------------------------------------------------------------- +@dataclass +class HFModuleReference(ReferenceModel): + """Wrap a ``torch.nn.Module`` loaded with HF weights. + + Parameters: + name: report label (e.g. ``"hf.Qwen3DecoderLayer"``). + module_factory: callable that returns the uninitialized module. The + caller decides how to build it (usually from ``AutoConfig`` + + class constructor). + forward_fn: callable ``(module, inputs) -> dict`` that invokes the + module and extracts named output tensors. This is the only place + model-specific plumbing lives. + state_dict_prefix: prefix to strip from HF keys before + ``load_state_dict`` (e.g. ``"model.layers.0."``). + dtype: compute dtype (default fp32 for a tight reference). 
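+
+    Example (a minimal sketch; ``build_layer`` and ``run_layer`` are placeholder
+    callables a case would supply, not names defined in this module)::
+
+        ref = HFModuleReference(
+            name="hf.Qwen3DecoderLayer",
+            module_factory=build_layer,   # () -> torch.nn.Module
+            forward_fn=run_layer,         # (module, inputs) -> dict[str, Tensor]
+            state_dict_prefix="model.layers.0.",
+            dtype=torch.float32,
+        )
+        ref.prepare(hf_state)          # load + validate the state dict
+        outputs = ref.forward(inputs)  # runs under torch.no_grad()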
+ """ + + name: str + module_factory: Callable[[], torch.nn.Module] + forward_fn: Callable[[torch.nn.Module, Mapping[str, torch.Tensor]], dict[str, torch.Tensor]] + state_dict_prefix: str = "" + dtype: torch.dtype = torch.float32 + strict_load: bool = True + + _module: torch.nn.Module | None = field(default=None, init=False, repr=False) + + def prepare(self, hf_state: Mapping[str, torch.Tensor]) -> None: + module = self.module_factory().to(self.dtype).eval() + sd: dict[str, torch.Tensor] = {} + prefix = self.state_dict_prefix + for k, v in hf_state.items(): + if prefix and not k.startswith(prefix): + continue + stripped = k[len(prefix):] if prefix else k + sd[stripped] = v.to(self.dtype) + missing, unexpected = module.load_state_dict(sd, strict=False) + if self.strict_load: + if missing: + raise RuntimeError(f"[{self.name}] missing HF keys: {missing}") + if unexpected: + raise RuntimeError(f"[{self.name}] unexpected HF keys: {unexpected}") + self._module = module + + def forward(self, inputs: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + if self._module is None: + raise RuntimeError(f"[{self.name}] prepare() must be called first.") + with torch.no_grad(): + return self.forward_fn(self._module, inputs) diff --git a/llm/testing/hf_compare/runner.py b/llm/testing/hf_compare/runner.py new file mode 100644 index 00000000..3d739509 --- /dev/null +++ b/llm/testing/hf_compare/runner.py @@ -0,0 +1,140 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Runner: drive a ComparisonCase end-to-end.""" +from __future__ import annotations + +from collections.abc import Mapping + +import torch + +from .base import ( + ChainedComparisonCase, + CompareReport, + ComparisonCase, + HFWeightSource, + OutputSelector, + SelectorResult, +) +from .comparator import Comparator + + +def _resolve_hf_weights(source: HFWeightSource) -> Mapping[str, torch.Tensor]: + if source is None: + return {} + if callable(source): + return source() + if isinstance(source, str): + from .weight_adapter import load_hf_state + return load_hf_state(source) + return source + + +def _apply_selector(t: torch.Tensor, sel: OutputSelector) -> torch.Tensor: + if sel.slice_ is not None: + t = t[sel.slice_] + if sel.postprocess is not None: + t = sel.postprocess(t) + return t + + +def run_case(case: ComparisonCase) -> CompareReport: + hf_state = _resolve_hf_weights(case.hf_weights) + + # Materialize identical inputs for both sides. Target may mutate in place, + # so give each side its own clone. + raw_inputs = case.input_spec.materialize() + if case.on_inputs is not None: + case.on_inputs(raw_inputs) + ref_inputs = {k: v.clone() for k, v in raw_inputs.items()} + tgt_inputs = {k: v.clone() for k, v in raw_inputs.items()} + + # Prepare reference (HF) side. 
+ case.reference.prepare(hf_state) + ref_out = case.reference.forward(ref_inputs) + + # Prepare target (PyPTO) side; adapt weights via declarative adapter. + kernel_weights = case.weight_adapter.adapt(hf_state) + case.target.prepare() + try: + tgt_out = case.target.run(tgt_inputs, kernel_weights) + finally: + case.target.teardown() + + # Compare each selected output. + comparator = Comparator(case.tolerance) + results: list[SelectorResult] = [] + for sel in case.selectors: + if sel.ref_key not in ref_out: + results.append( + SelectorResult( + selector=sel.name, + metrics={}, + passed=False, + note=f"ref_key {sel.ref_key!r} not in reference output (keys={sorted(ref_out)})", + ) + ) + continue + if sel.tgt_key not in tgt_out: + results.append( + SelectorResult( + selector=sel.name, + metrics={}, + passed=False, + note=f"tgt_key {sel.tgt_key!r} not in target output (keys={sorted(tgt_out)})", + ) + ) + continue + ref_t = _apply_selector(ref_out[sel.ref_key], sel) + tgt_t = _apply_selector(tgt_out[sel.tgt_key], sel) + results.append(comparator.compare(ref_t, tgt_t, sel)) + + passed = all(r.passed for r in results) and len(results) > 0 + report = CompareReport( + case_name=case.name, + passed=passed, + results=results, + meta={ + "reference": case.reference.name, + "target": case.target.name, + "num_selectors": len(case.selectors), + }, + ) + if case.on_report is not None: + case.on_report(report) + return report + + +def run_chain(chain: ChainedComparisonCase) -> list[CompareReport]: + reports: list[CompareReport] = [] + prev_ref_out: dict[str, torch.Tensor] = {} + for step in chain.steps: + if step.forward_map and prev_ref_out: + # Inject values from the previous step's reference outputs into + # this step's input spec by overriding materialize() via on_inputs. + mapping = step.forward_map + snapshot = {dst: prev_ref_out[src].clone() for src, dst in mapping.items() if src in prev_ref_out} + original_hook = step.case.on_inputs + + def _hook(inputs: dict[str, torch.Tensor], _snap=snapshot, _orig=original_hook) -> None: + for name, value in _snap.items(): + inputs[name] = value + if _orig is not None: + _orig(inputs) + + step.case.on_inputs = _hook + report = run_case(step.case) + reports.append(report) + if not report.passed and chain.stop_on_fail: + break + # Re-run the reference just to capture its outputs for the next step. + # (ComparisonCase.run did the work but did not return ref outputs, so + # we re-materialize on the fly. For now we skip and rely on tests + # using this chain to either keep state externally or extend the API.) + prev_ref_out = {} + return reports diff --git a/llm/testing/hf_compare/target.py b/llm/testing/hf_compare/target.py new file mode 100644 index 00000000..57a5f696 --- /dev/null +++ b/llm/testing/hf_compare/target.py @@ -0,0 +1,198 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- +"""Target-side wrappers: the system under test. + +Three flavors: + + - ``CallableTarget``: any ``(inputs, weights) -> dict`` function + (useful for CPU PyTorch goldens inside pypto-lib). + - ``PyPTOKernelTarget``: compile a pypto program once and execute it for + each run; maps input+weight dicts onto the + kernel's positional argument order. + - ``TorchTarget``: wrap an arbitrary ``nn.Module`` built from HF + weights (used as a sanity-check of the harness). +""" +from __future__ import annotations + +import os +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any + +import torch + +from .base import TargetModel + + +# --------------------------------------------------------------------------- +# Callable target +# --------------------------------------------------------------------------- +@dataclass +class CallableTarget(TargetModel): + name: str + fn: Callable[[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]], dict[str, torch.Tensor]] + + def prepare(self) -> None: + return None + + def run( + self, + inputs: Mapping[str, torch.Tensor], + weights: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + return self.fn(inputs, weights) + + def teardown(self) -> None: + return None + + +# --------------------------------------------------------------------------- +# PyPTO kernel target +# --------------------------------------------------------------------------- +_BACKEND_MAP: dict[str, str] = { + # platform arg string -> pypto.backend.BackendType attribute name + "a2a3": "Ascend910B", + "a2a3sim": "Ascend910B", + "a5": "Ascend950", + "a5sim": "Ascend950", +} + + +def _resolve_device_id(device_id: int | str | None) -> int: + if device_id is not None: + return int(device_id) + + for env_name in ("ASCEND_DEVICE_ID", "DEVICE_ID"): + env_value = os.environ.get(env_name) + if env_value is not None and env_value != "": + return int(env_value) + return 0 + + +@dataclass +class PyPTOKernelTarget(TargetModel): + """Compile a pypto program once and execute it per run. + + Args: + name: report label. + build_program: callable returning the program IR (no args). + spec_order: ordered list of tensor names matching the kernel signature. + platform: one of ``a2a3 / a2a3sim / a5 / a5sim``. + device_id: NPU device index. If omitted, resolve from + ``ASCEND_DEVICE_ID`` / ``DEVICE_ID`` and fall back to ``0``. + output_keys: names of tensors to expose in the returned dict. Default + is all names in ``spec_order`` (so callers can read mutated + buffers like ``out`` / ``k_cache`` / ``v_cache``). + post_run: optional hook ``(tensors) -> dict`` that derives additional + named outputs from the mutated tensor dict (e.g. gather KV rows). 
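+
+    Example (a hedged sketch; ``build_prog`` and ``SPEC_ORDER`` stand in for a
+    case's program factory and tensor ordering)::
+
+        target = PyPTOKernelTarget(
+            name="pypto.my_kernel[a2a3]",
+            build_program=build_prog,   # () -> program IR
+            spec_order=SPEC_ORDER,      # kernel positional tensor names
+            platform="a2a3",
+        )
+        target.prepare()                     # compiles once, reused across runs
+        out = target.run(inputs, weights)    # returns mutated buffers (+ post_run extras)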
+ """ + + name: str + build_program: Callable[[], Any] + spec_order: Sequence[str] + platform: str = "a2a3" + device_id: int | str | None = None + output_keys: Sequence[str] | None = None + post_run: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None + + _compiled_dir: str | None = field(default=None, init=False, repr=False) + _execute_fn: Callable[..., None] | None = field(default=None, init=False, repr=False) + + def prepare(self) -> None: + if self._compiled_dir is not None: + return + from pypto import ir + from pypto.backend import BackendType + from pypto.runtime import execute_compiled + + backend_name = _BACKEND_MAP.get(self.platform) + if backend_name is None: + raise ValueError(f"Unknown platform {self.platform!r}; known: {sorted(_BACKEND_MAP)}") + backend = getattr(BackendType, backend_name) + + program = self.build_program() + compiled = ir.compile(program, backend_type=backend) + self._compiled_dir = compiled.output_dir + self._execute_fn = execute_compiled + + def run( + self, + inputs: Mapping[str, torch.Tensor], + weights: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + if self._compiled_dir is None or self._execute_fn is None: + raise RuntimeError(f"[{self.name}] prepare() must be called first.") + + merged: dict[str, torch.Tensor] = {**weights, **inputs} + missing = [n for n in self.spec_order if n not in merged] + if missing: + raise KeyError(f"[{self.name}] missing tensors for spec_order: {missing}") + ordered = [merged[n] for n in self.spec_order] + + self._execute_fn( + self._compiled_dir, + ordered, + platform=self.platform, + device_id=_resolve_device_id(self.device_id), + ) + + keys = list(self.output_keys) if self.output_keys is not None else list(self.spec_order) + out: dict[str, torch.Tensor] = {k: merged[k] for k in keys if k in merged} + if self.post_run is not None: + out.update(self.post_run(merged)) + return out + + def teardown(self) -> None: + # Compilation artifacts live on disk; keep them for reuse. 
+ return None + + +# --------------------------------------------------------------------------- +# Torch module target (sanity check / pure-CPU harness) +# --------------------------------------------------------------------------- +@dataclass +class TorchTarget(TargetModel): + name: str + module_factory: Callable[[], torch.nn.Module] + forward_fn: Callable[[torch.nn.Module, Mapping[str, torch.Tensor]], dict[str, torch.Tensor]] + state_dict_prefix: str = "" + dtype: torch.dtype = torch.float32 + strict_load: bool = True + + _module: torch.nn.Module | None = field(default=None, init=False, repr=False) + + def prepare(self) -> None: + return None + + def run( + self, + inputs: Mapping[str, torch.Tensor], + weights: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + if self._module is None: + module = self.module_factory().to(self.dtype).eval() + prefix = self.state_dict_prefix + sd: dict[str, torch.Tensor] = {} + for k, v in weights.items(): + if prefix and not k.startswith(prefix): + continue + stripped = k[len(prefix):] if prefix else k + sd[stripped] = v.to(self.dtype) + missing, unexpected = module.load_state_dict(sd, strict=False) + if self.strict_load: + if missing: + raise RuntimeError(f"[{self.name}] missing keys: {missing}") + if unexpected: + raise RuntimeError(f"[{self.name}] unexpected keys: {unexpected}") + self._module = module + with torch.no_grad(): + return self.forward_fn(self._module, inputs) + + def teardown(self) -> None: + self._module = None diff --git a/llm/testing/hf_compare/weight_adapter.py b/llm/testing/hf_compare/weight_adapter.py new file mode 100644 index 00000000..3afaaea1 --- /dev/null +++ b/llm/testing/hf_compare/weight_adapter.py @@ -0,0 +1,206 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Declarative weight adapters: HF state_dict -> kernel-ready tensor dict. + +The common case is a per-tensor spec describing: + - where to read from (an HF key, optionally with a layer index prefix) + - a sequence of ops to transform the value (cast / transpose / reshape / view) + - the output key expected by the PyPTO kernel + +Example (Qwen3-14B decode, layer 0): + + DictAdapter({ + "input_rms_weight": + Map("input_layernorm.weight", ops=[View([1, HIDDEN]), Cast(torch.float32)]), + "wq": + Map("self_attn.q_proj.weight", ops=[Transpose(), Contiguous(), Cast(torch.bfloat16)]), + ... 
+ }, prefix="model.layers.0.") +""" +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field + +import torch + +from .base import WeightAdapter + + +# --------------------------------------------------------------------------- +# Ops: single-argument Callable[[Tensor], Tensor] +# --------------------------------------------------------------------------- +Op = Callable[[torch.Tensor], torch.Tensor] + + +def Cast(dtype: torch.dtype) -> Op: + def _op(t: torch.Tensor) -> torch.Tensor: + return t.to(dtype) + _op.__name__ = f"Cast({dtype})" + return _op + + +def Transpose(*dims: int) -> Op: + """Transpose the last two dims by default; pass dims for arbitrary perm.""" + def _op(t: torch.Tensor) -> torch.Tensor: + if not dims: + if t.ndim < 2: + return t + return t.t() + return t.permute(*dims) + _op.__name__ = f"Transpose{dims or ''}" + return _op + + +def Contiguous() -> Op: + def _op(t: torch.Tensor) -> torch.Tensor: + return t.contiguous() + _op.__name__ = "Contiguous" + return _op + + +def View(shape: Sequence[int]) -> Op: + shape_t = tuple(shape) + + def _op(t: torch.Tensor) -> torch.Tensor: + return t.view(*shape_t) + _op.__name__ = f"View({shape_t})" + return _op + + +def Reshape(shape: Sequence[int]) -> Op: + shape_t = tuple(shape) + + def _op(t: torch.Tensor) -> torch.Tensor: + return t.reshape(*shape_t) + _op.__name__ = f"Reshape({shape_t})" + return _op + + +def Clone() -> Op: + def _op(t: torch.Tensor) -> torch.Tensor: + return t.clone() + _op.__name__ = "Clone" + return _op + + +# --------------------------------------------------------------------------- +# Mapping spec +# --------------------------------------------------------------------------- +@dataclass +class Map: + """Read one HF key and transform it into one kernel tensor.""" + + src: str + ops: Sequence[Op] = () + + def apply(self, hf: Mapping[str, torch.Tensor], prefix: str) -> torch.Tensor: + key = prefix + self.src + if key not in hf: + raise KeyError(f"HF weight {key!r} not found (prefix={prefix!r}).") + t = hf[key] + for op in self.ops: + t = op(t) + return t + + +@dataclass +class Compute: + """Compute a kernel tensor from arbitrary HF inputs (escape hatch). + + ``fn`` receives the full HF mapping and the prefix and returns a tensor. + Use this for fused/merged weights, RoPE tables synthesized from config, + or constants not present in the state dict. + """ + + fn: Callable[[Mapping[str, torch.Tensor], str], torch.Tensor] + + def apply(self, hf: Mapping[str, torch.Tensor], prefix: str) -> torch.Tensor: + return self.fn(hf, prefix) + + +Spec = Map | Compute + + +# --------------------------------------------------------------------------- +# Adapter +# --------------------------------------------------------------------------- +@dataclass +class DictAdapter(WeightAdapter): + """Declarative adapter: a dict of ``kernel_key -> Map/Compute`` entries.""" + + mapping: dict[str, Spec] + prefix: str = "" + # Optional extras (e.g. zero-filled output buffers) merged as-is. 
+ extras: dict[str, Callable[[], torch.Tensor]] = field(default_factory=dict) + + def adapt(self, hf_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + out: dict[str, torch.Tensor] = {} + for key, spec in self.mapping.items(): + out[key] = spec.apply(hf_state, self.prefix) + for key, factory in self.extras.items(): + out[key] = factory() + return out + + def unadapt(self, kernel_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return {} + + +@dataclass +class PassthroughAdapter(WeightAdapter): + """Forward the raw HF state to the target unchanged.""" + + clone_tensors: bool = False + + def adapt(self, hf_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + if self.clone_tensors: + return {k: v.clone() for k, v in hf_state.items()} + return dict(hf_state) + + def unadapt(self, kernel_state: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return dict(kernel_state) + + +# --------------------------------------------------------------------------- +# HF safetensors loading helper +# --------------------------------------------------------------------------- +def load_hf_state(model_path: str, key_filter: Callable[[str], bool] | None = None) -> dict[str, torch.Tensor]: + """Load safetensors shards from ``model_path`` and return a state dict. + + ``key_filter`` lets callers load only a subset (e.g. one layer). + """ + import json + from pathlib import Path + + from safetensors import safe_open + + idx_path = Path(model_path) / "model.safetensors.index.json" + if idx_path.exists(): + with open(idx_path) as f: + weight_map: dict[str, str] = json.load(f)["weight_map"] + keys = [k for k in weight_map if key_filter is None or key_filter(k)] + file_to_keys: dict[str, list[str]] = {} + for k in keys: + file_to_keys.setdefault(weight_map[k], []).append(k) + out: dict[str, torch.Tensor] = {} + for fname, subkeys in sorted(file_to_keys.items()): + path = Path(model_path) / fname + with safe_open(str(path), framework="pt", device="cpu") as f: + for k in subkeys: + out[k] = f.get_tensor(k) + return out + + # Single-file fallback. + path = Path(model_path) / "model.safetensors" + out = {} + with safe_open(str(path), framework="pt", device="cpu") as f: + for k in f.keys(): + if key_filter is None or key_filter(k): + out[k] = f.get_tensor(k) + return out diff --git a/tests/llm/testing/test_hf_compare_batch_config.py b/tests/llm/testing/test_hf_compare_batch_config.py new file mode 100644 index 00000000..ba1359d4 --- /dev/null +++ b/tests/llm/testing/test_hf_compare_batch_config.py @@ -0,0 +1,58 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- + +from importlib.machinery import ModuleSpec + +import pytest + +from llm.testing.hf_compare.cases import qwen3_14b_decode, qwen3_14b_prefill + + +def _stub_module(monkeypatch, case_module, *, max_seq: int, block_size: int): + class _FakeLoader: + def create_module(self, spec): + return None + + def exec_module(self, module): + module.MAX_SEQ = max_seq + module.BLOCK_SIZE = block_size + + monkeypatch.setattr( + case_module.importlib.util, + "spec_from_file_location", + lambda name, path: ModuleSpec(name=name, loader=_FakeLoader(), origin=str(path)), + ) + + +def test_decode_case_accepts_batch_override(monkeypatch): + _stub_module(monkeypatch, qwen3_14b_decode, max_seq=4096, block_size=64) + case = qwen3_14b_decode.build(cpu_only=True, batch=16, seq_len=128, max_seq=256) + + tensors = case.input_spec.tensors + assert tensors["hidden_states"].shape == (16, qwen3_14b_decode.HIDDEN) + assert tensors["seq_lens"].shape == (16,) + assert tensors["slot_mapping"].shape == (16,) + + +def test_prefill_case_accepts_batch_override(monkeypatch): + _stub_module(monkeypatch, qwen3_14b_prefill, max_seq=128, block_size=64) + case = qwen3_14b_prefill.build(cpu_only=True, batch=16, seq_len=128, max_seq=256) + + tensors = case.input_spec.tensors + assert tensors["hidden_states"].shape == (16, 256, qwen3_14b_prefill.HIDDEN) + assert tensors["seq_lens"].shape == (16,) + assert tensors["slot_mapping"].shape == (16 * 256,) + + +@pytest.mark.parametrize("builder", [qwen3_14b_decode.build, qwen3_14b_prefill.build]) +def test_batch_must_be_positive(builder, monkeypatch): + case_module = qwen3_14b_decode if builder is qwen3_14b_decode.build else qwen3_14b_prefill + _stub_module(monkeypatch, case_module, max_seq=4096, block_size=64) + with pytest.raises(ValueError, match="batch must be positive"): + builder(cpu_only=True, batch=0) diff --git a/tests/llm/testing/test_hf_compare_e2e.py b/tests/llm/testing/test_hf_compare_e2e.py new file mode 100644 index 00000000..ffd345e5 --- /dev/null +++ b/tests/llm/testing/test_hf_compare_e2e.py @@ -0,0 +1,201 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- + +import torch +import pytest +from importlib.machinery import ModuleSpec + +from llm.testing.hf_compare.cases import qwen3_14b_e2e +from llm.testing.hf_compare.target import _resolve_device_id + + +def test_qwen3_e2e_case_runs_with_passthrough_weights(monkeypatch): + fake_state = { + "model.embed_tokens.weight": torch.randn(32, 16), + "model.norm.weight": torch.randn(16), + "lm_head.weight": torch.randn(32, 16), + "model.layers.0.self_attn.q_proj.weight": torch.randn(16, 16), + } + expected_generated = torch.tensor([[11, 12]], dtype=torch.int64) + expected_logits = torch.tensor([[0.5, -1.0, 3.0]], dtype=torch.float32) + calls: dict[str, object] = {} + + def fake_ref(inputs, hf_state, **kwargs): + calls["ref_inputs"] = inputs["input_ids"].clone() + calls["ref_state_keys"] = sorted(hf_state) + calls["ref_kwargs"] = kwargs + return { + "generated_ids": expected_generated.clone(), + "last_logits": expected_logits.clone(), + } + + def fake_target(inputs, weights, **kwargs): + calls["target_inputs"] = inputs["input_ids"].clone() + calls["target_state_keys"] = sorted(weights) + calls["target_kwargs"] = kwargs + return { + "generated_ids": expected_generated.clone(), + "last_logits": expected_logits.clone(), + } + + monkeypatch.setattr(qwen3_14b_e2e, "_hf_e2e_forward", fake_ref) + monkeypatch.setattr(qwen3_14b_e2e, "_run_e2e_target", fake_target) + + case = qwen3_14b_e2e.build( + hf_model_path="/tmp/fake-model", + prompt_ids="7,8,9", + max_new_tokens=2, + num_layers=1, + max_seq=8, + cpu_only=True, + ) + case.hf_weights = fake_state + + report = case.run() + + assert report.passed + assert [result.selector for result in report.results] == ["generated_ids", "last_logits"] + assert torch.equal(calls["ref_inputs"], torch.tensor([[7, 8, 9]], dtype=torch.int64)) + assert torch.equal(calls["target_inputs"], torch.tensor([[7, 8, 9]], dtype=torch.int64)) + assert calls["ref_state_keys"] == sorted(fake_state) + assert calls["target_state_keys"] == sorted(fake_state) + assert calls["ref_kwargs"] == { + "model_path": "/tmp/fake-model", + "num_layers": 1, + "max_new_tokens": 2, + "hf_dtype": torch.float32, + } + assert calls["target_kwargs"] == { + "max_new_tokens": 2, + "num_layers": 1, + "max_seq": 8, + "cpu_only": True, + "platform": "a2a3", + "device_id": None, + } + + +def test_qwen3_e2e_build_validates_max_seq(): + with pytest.raises(ValueError, match="exceeds max_seq=4"): + qwen3_14b_e2e.build( + prompt_ids="1,2,3,4", + max_new_tokens=2, + max_seq=4, + ) + + +def test_state_for_num_layers_keeps_requested_prefixes(): + hf_state = { + "model.embed_tokens.weight": torch.ones(1), + "model.norm.weight": torch.ones(1), + "lm_head.weight": torch.ones(1), + "model.layers.0.a": torch.ones(1), + "model.layers.1.b": torch.ones(1), + "model.layers.2.c": torch.ones(1), + "something.else": torch.ones(1), + } + + kept = qwen3_14b_e2e._state_for_num_layers(hf_state, num_layers=2) + + assert sorted(kept) == [ + "lm_head.weight", + "model.embed_tokens.weight", + "model.layers.0.a", + "model.layers.1.b", + "model.norm.weight", + ] + + +def test_resolve_device_id_prefers_explicit_then_env(monkeypatch): + monkeypatch.delenv("ASCEND_DEVICE_ID", raising=False) + monkeypatch.delenv("DEVICE_ID", raising=False) + assert _resolve_device_id(None) == 0 + + monkeypatch.setenv("DEVICE_ID", "6") + assert _resolve_device_id(None) == 6 + + monkeypatch.setenv("ASCEND_DEVICE_ID", "4") + assert _resolve_device_id(None) == 4 + 
assert _resolve_device_id("9") == 9 + + +def test_qwen3_e2e_uses_resolved_device_id_for_npu(monkeypatch): + class _FakeRunner: + def __init__(self, dec, prefill, *, batch, max_seq, platform, device_id): + seen["runner_init"] = (batch, max_seq, platform, device_id) + + def run_prefill(self, *args, **kwargs): + hidden = args[0] + return hidden + + def run_decode(self, hidden_states, *args, **kwargs): + return hidden_states + + class _FakeModule: + HIDDEN = 4 + HEAD_DIM = 2 + NUM_KV_HEADS = 1 + + class _FakeDec: + pass + + seen: dict[str, object] = {} + monkeypatch.setenv("ASCEND_DEVICE_ID", "7") + monkeypatch.setattr(qwen3_14b_e2e, "_load_qwen3_modules", lambda: (_FakeDec, _FakeModule)) + monkeypatch.setattr(qwen3_14b_e2e, "_extract_layer_weights", lambda weights, **kwargs: {"wq": torch.zeros(4, 4, dtype=torch.bfloat16)}) + monkeypatch.setattr(qwen3_14b_e2e, "_build_rope", lambda max_seq, head_dim: (torch.zeros(max_seq, head_dim), torch.zeros(max_seq, head_dim))) + monkeypatch.setattr(qwen3_14b_e2e, "_embed_tokens", lambda input_ids, embed_weight: torch.zeros(input_ids.shape[0], input_ids.shape[1], 4, dtype=torch.bfloat16)) + monkeypatch.setattr(qwen3_14b_e2e, "_final_rmsnorm", lambda hidden, norm_weight: hidden.float()) + monkeypatch.setattr(qwen3_14b_e2e, "_lm_head_forward", lambda hidden, lm_weight: torch.tensor([[1.0, 0.0]], dtype=torch.float32)) + monkeypatch.setattr(qwen3_14b_e2e, "_greedy_sample", lambda logits: torch.tensor([0], dtype=torch.int64)) + monkeypatch.setattr(qwen3_14b_e2e, "_NPURunner", _FakeRunner) + + out = qwen3_14b_e2e._run_e2e_target( + {"input_ids": torch.tensor([[1, 2]], dtype=torch.int64)}, + { + "model.embed_tokens.weight": torch.zeros(8, 4), + "model.norm.weight": torch.ones(4), + "lm_head.weight": torch.zeros(2, 4), + }, + max_new_tokens=1, + num_layers=1, + max_seq=8, + cpu_only=False, + platform="a2a3", + device_id=None, + ) + + assert seen["runner_init"] == (1, 8, "a2a3", 7) + assert out["generated_ids"].tolist() == [[0]] + + +def test_load_qwen3_modules_loads_decode_and_prefill(monkeypatch): + seen: dict[str, object] = {} + + class _FakeLoader: + def create_module(self, spec): + return None + + def exec_module(self, module): + seen.setdefault("modules", []).append(module) + + def fake_spec_from_file_location(name, path): + seen.setdefault("names", []).append(name) + seen.setdefault("paths", []).append(str(path)) + return ModuleSpec(name=name, loader=_FakeLoader(), origin=str(path)) + + monkeypatch.setattr(qwen3_14b_e2e.importlib.util, "spec_from_file_location", fake_spec_from_file_location) + + dec, prefill = qwen3_14b_e2e._load_qwen3_modules() + + assert seen["names"] == ["_qwen3_14b_decode_kernel", "_qwen3_14b_prefill_kernel"] + assert seen["paths"][0].endswith("examples/models/qwen3/14b/qwen3_14b_decode.py") + assert seen["paths"][1].endswith("examples/models/qwen3/14b/qwen3_14b_prefill.py") + assert dec is seen["modules"][0] + assert prefill is seen["modules"][1]