Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
58612a0
[Executorch] Add non-flash SDPA for decode
kimishpatel Apr 1, 2026
beb7b11
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 6, 2026
611b5fa
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 6, 2026
4f386fa
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 7, 2026
9200883
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 8, 2026
c241432
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 9, 2026
21e6f35
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 13, 2026
8a3b816
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 13, 2026
69eadac
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 23, 2026
0c787cc
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 24, 2026
ee5539a
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 27, 2026
97f030f
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 28, 2026
89d5c6b
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 29, 2026
6c887f8
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 30, 2026
412eeee
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel Apr 30, 2026
05609a8
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel May 1, 2026
056d9f3
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel May 5, 2026
5be478b
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel May 5, 2026
126e8bb
Update on "[Executorch] Add non-flash SDPA for decode"
kimishpatel May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .ci/scripts/test_lora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start
Okay, so I need to calculate 15% of 80."
EXPECTED_QUANT_LORA_PREFIX="
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
To calculate 15% of 80, we can multiply 80 by 15/100.
To calculate 15% of 80, we can multiply 80 by 15/100 and then simplify the fraction.
So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
#### 12
The answer is: 12<|im_end|>"
Expand All @@ -168,6 +168,7 @@ So, 15% of 80 is 12.
#### 12
The answer is: 12<|im_end|>"


# Export Quantized PTE, PTD file, no LoRA.
# override base.lora_config=null to avoid creating a lora model
# and loading lora weights.
Expand Down Expand Up @@ -227,7 +228,7 @@ fi
NOW=$(date +"%H:%M:%S")
echo "Test 4: Quantized, program-data separation lora. Starting to run llama runner at ${NOW}"
# shellcheck source=/dev/null
cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} --seq_len=104 > result.txt
NOW=$(date +"%H:%M:%S")
echo "Finished at ${NOW}"

Expand Down
292 changes: 292 additions & 0 deletions extension/llm/custom_ops/op_custom_sdpa_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include <executorch/extension/llm/custom_ops/op_sdpa.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::runtime::testing::TensorFactory;

namespace {

// Helper: call custom_sdpa_out. Inputs use [B, S, H, D] layout.
executorch::aten::Tensor call_custom_sdpa(
const executorch::aten::Tensor& q,
const executorch::aten::Tensor& k,
const executorch::aten::Tensor& v,
int64_t start_pos,
const std::optional<executorch::aten::Tensor>& attn_mask,
double dropout_p,
bool is_causal,
std::optional<double> scale,
executorch::aten::Tensor& out) {
executorch::runtime::KernelRuntimeContext ctx{};
return torch::executor::native::custom_sdpa_out(
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, out);
}

/**
* Naive reference SDPA for [B, S, H, D] layout.
* Element [b,s,h,d] is at index b*S*H*D + s*H*D + h*D + d.
* Only first num_valid_keys KV entries are used.
*/
void compute_reference_sdpa(
const float* q_data,
int B,
int qS,
int qH,
int D,
const float* k_data,
int kvS,
int kvH,
const float* v_data,
float* out_data,
bool is_causal,
int64_t start_pos,
int num_valid_keys) {
float scale = 1.0f / std::sqrt(static_cast<float>(D));
int num_reps = qH / kvH;

for (int b = 0; b < B; b++) {
for (int h = 0; h < qH; h++) {
int kv_h = h / num_reps;
for (int qs = 0; qs < qS; qs++) {
// scores = Q @ K^T * scale
std::vector<float> scores(num_valid_keys);
for (int kvs = 0; kvs < num_valid_keys; kvs++) {
float dot = 0;
for (int d = 0; d < D; d++) {
float qv = q_data[b * qS * qH * D + qs * qH * D + h * D + d];
float kv = k_data[b * kvS * kvH * D + kvs * kvH * D + kv_h * D + d];
dot += qv * kv;
}
scores[kvs] = dot * scale;
}

// Causal mask
if (is_causal) {
int64_t valid = std::min(
start_pos + qs + 1, static_cast<int64_t>(num_valid_keys));
for (int64_t j = valid; j < num_valid_keys; j++) {
scores[j] = -std::numeric_limits<float>::infinity();
}
}

// Softmax
float max_val = *std::max_element(scores.begin(), scores.end());
float sum = 0;
for (auto& s : scores) {
s = std::exp(s - max_val);
sum += s;
}
if (sum > 0) {
for (auto& s : scores) {
s /= sum;
}
}

// output = scores @ V
for (int d = 0; d < D; d++) {
float val = 0;
for (int kvs = 0; kvs < num_valid_keys; kvs++) {
float vv = v_data[b * kvS * kvH * D + kvs * kvH * D + kv_h * D + d];
val += scores[kvs] * vv;
}
out_data[b * qS * qH * D + qs * qH * D + h * D + d] = val;
}
}
}
}
}

} // namespace

// With a single KV entry (start_pos=0), output must equal V[0].
TEST(OpCustomSdpaTest, DecodeSingleKV) {
TensorFactory<executorch::aten::ScalarType::Float> tf;

executorch::aten::Tensor q = tf.make(
{1, 1, 2, 4},
{0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});

executorch::aten::Tensor k = tf.make(
{1, 1, 2, 4},
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317});

executorch::aten::Tensor v = tf.make(
{1, 1, 2, 4},
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});

// softmax of a single score is always 1.0, so output == V
executorch::aten::Tensor expected = tf.make(
{1, 1, 2, 4},
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});

executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
call_custom_sdpa(q, k, v, /*start_pos=*/0, {}, 0.0, false, {}, out);
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-6, 1e-6);
}

// Decode with 3 valid KV entries, verified against reference computation.
TEST(OpCustomSdpaTest, DecodeNonCausal) {
TensorFactory<executorch::aten::ScalarType::Float> tf;

// Q: [B=1, S=1, H=2, D=4]
executorch::aten::Tensor q = tf.make(
{1, 1, 2, 4},
{0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});

// K, V: [B=1, kv_len=4, H=2, D=4], first 3 entries valid
executorch::aten::Tensor k = tf.make(
{1, 4, 2, 4},
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});

executorch::aten::Tensor v = tf.make(
{1, 4, 2, 4},
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});

int64_t start_pos = 2;
int num_valid = 3;

std::vector<float> ref(8, 0.0f);
compute_reference_sdpa(
q.const_data_ptr<float>(),
1,
1,
2,
4,
k.const_data_ptr<float>(),
4,
2,
v.const_data_ptr<float>(),
ref.data(),
false,
start_pos,
num_valid);

executorch::aten::Tensor expected = tf.make({1, 1, 2, 4}, ref);
executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
}

// GQA: 4 query heads sharing 2 KV heads.
TEST(OpCustomSdpaTest, DecodeGQA) {
TensorFactory<executorch::aten::ScalarType::Float> tf;

// Q: [B=1, S=1, H_q=4, D=4]
executorch::aten::Tensor q = tf.make(
{1, 1, 4, 4},
{0.8823,
0.9150,
0.3829,
0.9593,
0.3904,
0.6009,
0.2566,
0.7936,
0.9408,
0.1332,
0.9346,
0.5936,
0.8694,
0.5677,
0.7411,
0.4294});

// K: [B=1, kv_len=3, H_kv=2, D=4]
executorch::aten::Tensor k =
tf.make({1, 3, 2, 4}, {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414,
0.2969, 0.8317, 0.1053, 0.2695, 0.3588, 0.1994,
0.5472, 0.0062, 0.9516, 0.0753, 0.8860, 0.5832,
0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423});

// V: [B=1, kv_len=3, H_kv=2, D=4]
executorch::aten::Tensor v =
tf.make({1, 3, 2, 4}, {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814,
0.7886, 0.5895, 0.7539, 0.1952, 0.0050, 0.3068,
0.1165, 0.9103, 0.6440, 0.7071, 0.6581, 0.4913,
0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278});

int64_t start_pos = 2;
int num_valid = 3;

std::vector<float> ref(16, 0.0f);
compute_reference_sdpa(
q.const_data_ptr<float>(),
1,
1,
4,
4,
k.const_data_ptr<float>(),
3,
2,
v.const_data_ptr<float>(),
ref.data(),
false,
start_pos,
num_valid);

executorch::aten::Tensor expected = tf.make({1, 1, 4, 4}, ref);
executorch::aten::Tensor out = tf.zeros({1, 1, 4, 4});
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
}

// For seq_len=1, causal mask doesn't restrict any positions
// (all start_pos+1 entries are visible), so result must match non-causal.
TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
TensorFactory<executorch::aten::ScalarType::Float> tf;

executorch::aten::Tensor q = tf.make(
{1, 1, 2, 4},
{0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});

executorch::aten::Tensor k = tf.make(
{1, 4, 2, 4},
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});

executorch::aten::Tensor v = tf.make(
{1, 4, 2, 4},
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});

int64_t start_pos = 2;

executorch::aten::Tensor out_nc = tf.zeros({1, 1, 2, 4});
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out_nc);

executorch::aten::Tensor out_c = tf.zeros({1, 1, 2, 4});
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, true, {}, out_c);

EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
}
Loading
Loading