Skip to content

Qualcomm AI Engine Direct - Refactor llama runner for dynamic IO dtypes#19146

Open
chenweng-quic wants to merge 1 commit into
pytorch:mainfrom
CodeLinaro:dev1/chenweng/support_llama_dynamic_io_dtypes
Open

Qualcomm AI Engine Direct - Refactor llama runner for dynamic IO dtypes#19146
chenweng-quic wants to merge 1 commit into
pytorch:mainfrom
CodeLinaro:dev1/chenweng/support_llama_dynamic_io_dtypes

Conversation

@chenweng-quic
Copy link
Copy Markdown
Collaborator

@chenweng-quic chenweng-quic commented Apr 27, 2026

Summary

To enable GPU backend support in the Llama runner, refactoring is required because the dtypes of kv_cache, attention_mask, and logits are currently hardcoded, preventing floating‑point models from running.
This PR focuses on removing the hardcode dtype for them.

Key changes

  • Remove template parameter from KVManager, LhdTokenGenerator,
    MultimodalPromptProcessor, and related runner classes
  • Detect kv_cache and attention_mask dtypes dynamically from MethodMeta at
    construction time instead of compile-time bitwidth detection
  • Switch to std::byte* pointer arithmetic with getDtypeSize() for all buffer
    offsets; add fill_mask() helper for multi-dtype attention mask filling
  • Update spec_prop pass for custom llama op for sharding case greater than 1

Test plan

python ${EXECUTORCH_ROOT}/backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder /local/mnt/workspace/chenweng/executorch/executorch/build-android --host mlgtw-linux --device bebcca9b --executorch_root ${EXECUTORCH_ROOT} --artifact_dir ./stories_110m_pte_size --llama_artifacts . --use_fp16
image

cc @cccclai @cbilgin @abhinaykukkadapu

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Apr 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19146

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 2 New Failures, 2 Unclassified Failures

As of commit 8222b4b with merge base 824cbff (image):

NEW FAILURES - The following jobs have failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 27, 2026
@chenweng-quic chenweng-quic added the module: qnn Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/ label Apr 27, 2026
@github-actions
Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@chenweng-quic chenweng-quic force-pushed the dev1/chenweng/support_llama_dynamic_io_dtypes branch from b06f047 to 66fafc5 Compare April 28, 2026 01:37
@chenweng-quic
Copy link
Copy Markdown
Collaborator Author

Hi @abhinaykukkadapu,
This PR touch the spec_prop_pass since the spec written in build_quant_io is restore by spec_prop_pass when num_sharding greater than 1.
e.g.,

after build_quant_io
input -> call_delegate1 (spec=uint16) -> llama.fallback (spec=uint16) -> call_delegate2 (spec=uint16) -> output
after spec_prop_pass 
input -> call_delegate1 (spec=uint16) -> llama.fallback (spec=float) -> call_delegate2 (spec=uint16) -> output

It will cause following error in op_fallback.cpp
image

To fix the dtype issue, I add logic in spec_prop_pass to avoid restoring.
I am not sure if this is correct fix, could you share your opinion?

@chenweng-quic
Copy link
Copy Markdown
Collaborator Author

chenweng-quic commented Apr 29, 2026

Hi @abhinaykukkadapu,
Could you tag claude to review this PR?
Thanks

@abhinaykukkadapu
Copy link
Copy Markdown
Contributor

@claude review this PR

@claude
Copy link
Copy Markdown

claude Bot commented Apr 29, 2026

Claude finished @abhinaykukkadapu's task in 3m 50s —— View job


PR Review: Qualcomm AI Engine Direct - Refactor llama runner for dynamic IO dtypes

  • Read and analyze all changed files
  • Review C++ runtime changes (kv_manager, token_generator, prompt_processor, etc.)
  • Review Python export/pass changes (build_quant_io, spec_prop_pass, llama.py, etc.)
  • Review new op_fallback.py file
  • Check for correctness, safety, and style issues
  • Post review findings

Overall Assessment

This is a well-motivated refactor that removes compile-time dtype dispatch (template <typename T>) in favor of runtime dtype detection via MethodMeta. The approach is sound — detecting kv_cache_dtype_ and attention_mask_dtype_ from the model metadata at construction time is cleaner and more extensible than the previous KvBitWidth enum + template instantiation approach.

The PR is large (33 files, ~800 additions / ~700 deletions) but most changes are mechanical: removing template parameters, switching from typed pointers to std::byte* with getDtypeSize() arithmetic, and threading MethodMeta through constructors.


Issues Found

1. Bug: fill_mask uses wrong "negative" value for float masks

kv_manager.cpp:51 — When use_pos_value=false and dtype is Float, the mask is filled with -65535.0. For quantized integer types, 0 is the "masked" value — but for float attention masks, the standard convention is -inf (or a very large negative like -1e9). Using -65535.0 as a magic number is fragile and may not match what the model expects if it was trained with standard float masking.

Consider using -std::numeric_limits<float>::infinity() or at least documenting why -65535.0 was chosen. If the model's attention implementation specifically expects this value, a comment would clarify the intent.

2. Bug: getDtypeSize missing Half / Float16 support

utils.h:30-51 — The getDtypeSize() function handles Float, Double, Int, Long, Byte, and UInt16 but does not handle Half (float16). Given the PR's stated goal is to support floating-point models (including fp16 via --use_fp16), this seems like a critical omission. If the kv_cache or attention_mask is exported with Half dtype, getDtypeSize() will hit the ET_CHECK_MSG(false, ...) path at runtime.

Fix this →

3. Bug: fill_mask also missing Half support

kv_manager.cpp:19-61 — Similarly, fill_mask() only handles UInt16, Byte, and Float. If the attention mask ends up as Half dtype, this will crash. Should add a Half case.

4. decoder_runner.h:logits_to_token also missing Half support

decoder_runner.h:74-89 — The logits_to_token switch only handles UInt16, Byte, and Float. If logits come back as Half, this will fail.

5. Redundant semicolon (style nit)

kv_manager.cpp:140 — There's a trailing double semicolon:

      std::byte* parent_ptr = attention_mask +
          pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_);
      ;

6. KV cache dtype detection heuristic is fragile

kv_manager.cpp:71-89 — The constructor iterates over inputs starting at index 2 to find the kv_cache by matching tensor_nbytes == expected_tensor_nbytes. This heuristic could misidentify a non-kv-cache tensor that happens to match the expected byte size. It also assumes that all inputs from index 2 onward that don't match this size are non-kv-cache tensors (like input_pos or sliding window attention mask), but there's no validation that at least one kv_cache tensor was actually found.

Additionally, the check tensor_meta->scalar_type() != kv_cache_dtype_ at line 75 will fire on non-kv-cache tensors (like input_pos which is Int), because once kv_cache_dtype_ is set, any subsequent tensor with a different dtype triggers the error message "Currently mixed scalar type of kv_cache is not allowed". This seems like a real bug — the non-kv-cache inputs between index 2 and the first kv_cache tensor will trigger a false mixed-dtype error.

Looking more carefully: the check at line 74 only fires when kv_cache_dtype_ != Undefined, and kv_cache_dtype_ is only set when a tensor matches the expected size (line 83-85). So the logic is:

  1. First, it scans for the first tensor matching the expected kv_cache size and records its dtype.
  2. Then, for subsequent tensors that don't match the expected size, the dtype comparison still runs but those tensors have different sizes so they'd mismatch but the code would still check their dtype...

Actually, the mixed-dtype check at line 74-78 runs before the size check at line 80-89. So if kv_cache_dtype_ is already set (from a previous kv_cache tensor), and then we encounter input_pos (which is Int dtype and different size), we'd enter the if at line 74 and hit the error. This seems like a real bug — the dtype check should be guarded by the size check.

Fix this →

7. spec_prop_pass change — approach seems reasonable

exir/passes/spec_prop_pass.py:82-95 — The change adds exir_ops.edge.llama.fallback.default alongside executorch_call_delegate in the condition that preserves existing specs instead of regenerating them. This makes sense: the llama.fallback op acts as a passthrough between delegates, and its spec should be preserved just like delegate call specs. The comment from @chenweng-quic explains the motivation well — spec_prop_pass was restoring the spec set by build_quant_io, breaking the dtype for the fallback op in the sharding case.

One concern: the import executorch.extension.llm.custom_ops.op_fallback at the top of spec_prop_pass.py (line 13) is a side-effect import that registers the llama.fallback op. This creates a dependency from the core exir package on an LLM-specific custom op. If op_fallback.py isn't available (e.g., in a minimal install), this import would fail and break spec_prop_pass for all users, not just QNN/llama users. Consider making this a lazy import or guarding it with a try/except.

Fix this →

8. op_fallback.py extraction is clean

The extraction of the llama.fallback op registration from model_sharding.py into its own op_fallback.py file is a good separation of concerns. The import in model_sharding.py (line 10) ensures backward compatibility.


CI Failures

The PR has 2 CI failures:

  • test-sqnr-static-llm-qnn-linux (smollm2_135m)test_static_llm_model
  • test-static-llama-qnn-linux (stories_260k_bc)test_llama_stories_260k

These may be related to the spec_prop_pass change or the build_quant_io changes affecting non-fp16 quantized models. Worth investigating whether the kv_cache dtype detection logic in the C++ runner works correctly for the 8-bit quantized case that these tests exercise.


Summary

The core refactoring approach (removing templates, using runtime dtype detection) is the right direction. The main concerns are:

  1. Critical: Missing Half dtype support in getDtypeSize(), fill_mask(), and logits_to_token() — this could cause runtime crashes for the fp16 use case this PR is enabling.
  2. Critical: The kv_cache dtype detection loop in KVManager constructor may trigger false "mixed scalar type" errors on non-kv-cache input tensors.
  3. Medium: The unconditional import of op_fallback in spec_prop_pass.py could break non-LLM users.
  4. Minor: -65535.0 magic number in float mask, trailing semicolon.

- Summary
To enable GPU backend support in the Llama runner, refactoring is required because the dtypes of kv_cache, attention_mask, and logits are currently hardcoded, preventing floating point models from running.
This PR fix the issue by removing the hardcoded dtype in the runner.
@chenweng-quic chenweng-quic force-pushed the dev1/chenweng/support_llama_dynamic_io_dtypes branch from 13fbb40 to 8222b4b Compare May 17, 2026 11:47
@chenweng-quic chenweng-quic requested a review from psiddh as a code owner May 17, 2026 11:47
@chenweng-quic
Copy link
Copy Markdown
Collaborator Author

Hi @psiddh,
Could you take a look of this PR?
Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: qnn Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants