Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d70fe8e
Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)
kimishpatel Apr 1, 2026
1083d69
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 6, 2026
b08de1b
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 6, 2026
1a34cd4
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 7, 2026
11b9a55
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 8, 2026
4b9da2b
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 9, 2026
6acda1b
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 13, 2026
98ebd08
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 13, 2026
e4ba4cf
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 14, 2026
d78292a
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 23, 2026
f2b105d
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 24, 2026
f88685e
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 27, 2026
a073a5b
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 28, 2026
0d396d0
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 29, 2026
cc3c3fe
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 30, 2026
ee90ded
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel Apr 30, 2026
7914266
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel May 1, 2026
bd8b985
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel May 5, 2026
7b01818
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel May 5, 2026
8e6f0c3
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel May 5, 2026
1d77b03
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel May 5, 2026
4474b91
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_l…
kimishpatel May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .ci/scripts/test_lora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,12 @@ else
fi

### QUANTIZATION & PROGRAM DATA SEPARATION ###
EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
<think>
EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: me
Okay, so I need to calculate 15% of 80."
EXPECTED_QUANT_LORA_PREFIX="
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
To calculate 15% of 80, we can multiply 80 by 15/100 and then simplify the fraction.
So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
To calculate 15% of 80, we can multiply 80 by 15/100.
So, 15% of 80 is equal to 80 * 15/100 = 12.
#### 12
The answer is: 12<|im_end|>"
EXPECTED_QUANT_LORA_ALTERNATE_PREFIX="
Expand All @@ -169,6 +168,7 @@ So, 15% of 80 is 12.
The answer is: 12<|im_end|>"



# Export Quantized PTE, PTD file, no LoRA.
# override base.lora_config=null to avoid creating a lora model
# and loading lora weights.
Expand Down Expand Up @@ -228,7 +228,7 @@ fi
NOW=$(date +"%H:%M:%S")
echo "Test 4: Quantized, program-data separation lora. Starting to run llama runner at ${NOW}"
# shellcheck source=/dev/null
cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} --seq_len=104 > result.txt
cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
NOW=$(date +"%H:%M:%S")
echo "Finished at ${NOW}"

Expand Down
3 changes: 1 addition & 2 deletions .ci/scripts/test_lora_multimethod.sh
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ EXPECTED_LORA_PREFIX="
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
To calculate 15% of 80"

EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
<think>
EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: me
Okay, so I need to calculate 15% of 80."

### TEST 1: Run lora_forward method ###
Expand Down
7 changes: 6 additions & 1 deletion extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,12 @@ Tensor& custom_sdpa_out_impl(
InvalidArgument,
output);

bool use_unfused_sdpa = seq_len == 1;
// Quantized GEMM kernels may not handle non-contiguous per-head strides
// correctly when seq_dim=ONE and seq_len > 1, so keep the conservative
// condition for quantized inputs.
bool is_quantized = q.scalar_type() == ScalarType::Char;
bool use_unfused_sdpa = (!is_quantized) &&
(seq_len <= 128 || num_keys_for_causal_attention <= 128);
if (use_unfused_sdpa) {
ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
sdpa::impl::cpu_sdpa<CTYPE>(
Expand Down
Loading