Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
f46b5d2
Add quantized input support to cpu_sdpa
kimishpatel Apr 1, 2026
01ab6c4
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 6, 2026
fc9ff9a
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 6, 2026
f45f8c8
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 7, 2026
0676be7
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 8, 2026
fd39b36
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 9, 2026
f5a702d
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 13, 2026
4141178
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 13, 2026
6e563e4
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 23, 2026
468f5b4
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 24, 2026
2b56ad0
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 27, 2026
6ef4741
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 28, 2026
b4edc11
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 29, 2026
8ca4001
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 30, 2026
adef0e2
Update on "Add quantized input support to cpu_sdpa"
kimishpatel Apr 30, 2026
cb2d0fe
Update on "Add quantized input support to cpu_sdpa"
kimishpatel May 1, 2026
69b6d08
Update on "Add quantized input support to cpu_sdpa"
kimishpatel May 5, 2026
99b8424
Update on "Add quantized input support to cpu_sdpa"
kimishpatel May 5, 2026
69f358a
Update on "Add quantized input support to cpu_sdpa"
kimishpatel May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 186 additions & 2 deletions extension/llm/custom_ops/op_custom_sdpa_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -7,8 +7,9 @@
*/

// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
// seq_len == 1 (the decode fast-path). Covers both float and quantized
// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out
// directly, not through sdpa_with_kv_cache.

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -114,6 +115,55 @@
}
}

/**
 * Dequantize an int8 tensor laid out as [B, S, H, D] using per-token
 * quantization parameters laid out as [B, S, H, 1].
 * dequant(x) = (x - zero_point) * scale
 *
 * Because both layouts are fully contiguous, token t (= b*S*H + s*H + h)
 * owns elements [t*D, (t+1)*D) of `data`/`out` and entry t of
 * `scales`/`zps`, so a single flat token loop is equivalent to the
 * nested b/s/h iteration.
 */
void dequantize_per_token(
    const int8_t* data, int B, int S, int H, int D,
    const float* scales,
    const int8_t* zps,
    float* out) {
  const int num_tokens = B * S * H;
  for (int t = 0; t < num_tokens; t++) {
    const float scale = scales[t];
    const float zero = static_cast<float>(zps[t]);
    const int base = t * D;
    for (int d = 0; d < D; d++) {
      out[base + d] = (static_cast<float>(data[base + d]) - zero) * scale;
    }
  }
}

// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout
// (is_seq_at_dim_1 = false, i.e. sequence is dim 1 of the 4-D tensors).
executorch::aten::Tensor call_custom_quantized_sdpa(
    const executorch::aten::Tensor& q,
    const executorch::aten::Tensor& k,
    const executorch::aten::Tensor& v,
    int64_t start_pos,
    const std::optional<executorch::aten::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    const std::optional<executorch::aten::Tensor>& q_zp,
    const std::optional<executorch::aten::Tensor>& q_sc,
    const std::optional<executorch::aten::Tensor>& k_zp,
    const std::optional<executorch::aten::Tensor>& k_sc,
    const std::optional<executorch::aten::Tensor>& v_zp,
    const std::optional<executorch::aten::Tensor>& v_sc,
    executorch::aten::Tensor& out) {
  // A default-constructed runtime context is sufficient for these tests.
  executorch::runtime::KernelRuntimeContext runtime_ctx{};
  return torch::executor::native::custom_quantized_sdpa_out(
      runtime_ctx,
      q,
      k,
      v,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      q_zp,
      q_sc,
      k_zp,
      k_sc,
      v_zp,
      v_sc,
      /*is_seq_at_dim_1=*/false,
      out);
}

} // namespace

// With a single KV entry (start_pos=0), output must equal V[0].
Expand Down Expand Up @@ -263,3 +313,137 @@

EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
}

// Quantized decode: int8 Q/K/V with per-token scales and zero_points,
// verified against a dequantize-then-float-SDPA reference.
TEST(OpCustomSdpaTest, DecodeQuantized) {
  TensorFactory<executorch::aten::ScalarType::Char> tf_int8;
  TensorFactory<executorch::aten::ScalarType::Float> tf_float;

  // Quantized query, [B=1, S=1, H=2, D=4].
  auto query = tf_int8.make(
      {1, 1, 2, 4},
      {10, 20, -5, 15, -10, 5, 25, -20});

  // Quantized key cache, [B=1, kv_len=3, H=2, D=4].
  auto key = tf_int8.make(
      {1, 3, 2, 4},
      {8, -12, 18, 5, -3, 22, -8, 14,
       15, 7, -20, 10, 12, -15, 9, 6,
       -5, 25, 3, -10, 20, 8, -12, 17});

  // Quantized value cache, [B=1, kv_len=3, H=2, D=4].
  auto value = tf_int8.make(
      {1, 3, 2, 4},
      {5, 15, -8, 20, 10, -5, 18, 12,
       -12, 8, 22, -3, 7, 20, -10, 15,
       18, -5, 10, 3, -8, 12, 5, -20});

  // Per-token quantization params: scales and zero_points are
  // [B, S (or kv_len), H, 1]. All zero_points are 0 here.
  auto query_scales = tf_float.make({1, 1, 2, 1}, {0.05f, 0.05f});
  auto key_scales = tf_float.make(
      {1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto value_scales = tf_float.make(
      {1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto query_zeros = tf_int8.make({1, 1, 2, 1}, {0, 0});
  auto key_zeros = tf_int8.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
  auto value_zeros = tf_int8.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});

  const int64_t start_pos = 2;
  const int num_valid = 3;

  // Build the float reference: dequantize every input, then run the
  // float reference SDPA on the dequantized values.
  std::vector<float> q_f32(8), k_f32(24), v_f32(24);
  dequantize_per_token(
      query.const_data_ptr<int8_t>(), 1, 1, 2, 4,
      query_scales.const_data_ptr<float>(),
      query_zeros.const_data_ptr<int8_t>(),
      q_f32.data());
  dequantize_per_token(
      key.const_data_ptr<int8_t>(), 1, 3, 2, 4,
      key_scales.const_data_ptr<float>(),
      key_zeros.const_data_ptr<int8_t>(),
      k_f32.data());
  dequantize_per_token(
      value.const_data_ptr<int8_t>(), 1, 3, 2, 4,
      value_scales.const_data_ptr<float>(),
      value_zeros.const_data_ptr<int8_t>(),
      v_f32.data());

  std::vector<float> expected_vals(8, 0.0f);
  compute_reference_sdpa(
      q_f32.data(), 1, 1, 2, 4,
      k_f32.data(), 3, 2,
      v_f32.data(),
      expected_vals.data(), false, start_pos, num_valid);

  auto expected = tf_float.make({1, 1, 2, 4}, expected_vals);
  auto actual = tf_float.zeros({1, 1, 2, 4});
  call_custom_quantized_sdpa(
      query, key, value, start_pos, {}, 0.0, false, {},
      query_zeros, query_scales,
      key_zeros, key_scales,
      value_zeros, value_scales,
      actual);
  EXPECT_TENSOR_CLOSE_WITH_TOL(actual, expected, 1e-3, 1e-3);
}

// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs.
TEST(OpCustomSdpaTest, DecodeQuantizedGQA) {
  TensorFactory<executorch::aten::ScalarType::Char> tf_int8;
  TensorFactory<executorch::aten::ScalarType::Float> tf_float;

  // Quantized query, [B=1, S=1, H_q=4, D=4].
  auto query = tf_int8.make(
      {1, 1, 4, 4},
      {10, 20, -5, 15, -10, 5, 25, -20,
       8, -3, 12, 7, -15, 18, 4, -8});

  // Quantized key cache, [B=1, kv_len=3, H_kv=2, D=4].
  auto key = tf_int8.make(
      {1, 3, 2, 4},
      {8, -12, 18, 5, -3, 22, -8, 14,
       15, 7, -20, 10, 12, -15, 9, 6,
       -5, 25, 3, -10, 20, 8, -12, 17});

  // Quantized value cache, [B=1, kv_len=3, H_kv=2, D=4].
  auto value = tf_int8.make(
      {1, 3, 2, 4},
      {5, 15, -8, 20, 10, -5, 18, 12,
       -12, 8, 22, -3, 7, 20, -10, 15,
       18, -5, 10, 3, -8, 12, 5, -20});

  // Per-token quantization params, [B, S (or kv_len), H, 1];
  // all zero_points are 0 here.
  auto query_scales =
      tf_float.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f});
  auto key_scales = tf_float.make(
      {1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto value_scales = tf_float.make(
      {1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto query_zeros = tf_int8.make({1, 1, 4, 1}, {0, 0, 0, 0});
  auto key_zeros = tf_int8.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
  auto value_zeros = tf_int8.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});

  const int64_t start_pos = 2;
  const int num_valid = 3;

  // Dequantize inputs, then run the float reference SDPA (which
  // handles the 4:2 query-to-KV head mapping) to get expected output.
  std::vector<float> q_f32(16), k_f32(24), v_f32(24);
  dequantize_per_token(
      query.const_data_ptr<int8_t>(), 1, 1, 4, 4,
      query_scales.const_data_ptr<float>(),
      query_zeros.const_data_ptr<int8_t>(),
      q_f32.data());
  dequantize_per_token(
      key.const_data_ptr<int8_t>(), 1, 3, 2, 4,
      key_scales.const_data_ptr<float>(),
      key_zeros.const_data_ptr<int8_t>(),
      k_f32.data());
  dequantize_per_token(
      value.const_data_ptr<int8_t>(), 1, 3, 2, 4,
      value_scales.const_data_ptr<float>(),
      value_zeros.const_data_ptr<int8_t>(),
      v_f32.data());

  std::vector<float> expected_vals(16, 0.0f);
  compute_reference_sdpa(
      q_f32.data(), 1, 1, 4, 4,
      k_f32.data(), 3, 2,
      v_f32.data(),
      expected_vals.data(), false, start_pos, num_valid);

  auto expected = tf_float.make({1, 1, 4, 4}, expected_vals);
  auto actual = tf_float.zeros({1, 1, 4, 4});
  call_custom_quantized_sdpa(
      query, key, value, start_pos, {}, 0.0, false, {},
      query_zeros, query_scales,
      key_zeros, key_scales,
      value_zeros, value_scales,
      actual);
  EXPECT_TENSOR_CLOSE_WITH_TOL(actual, expected, 1e-3, 1e-3);
}
8 changes: 5 additions & 3 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -412,15 +412,17 @@
InvalidArgument,
output);

bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char &&
seq_len == 1;
bool use_unfused_sdpa = seq_len == 1;
if (use_unfused_sdpa) {
ET_SWITCH_FLOAT_TYPES(
output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
sdpa::impl::cpu_sdpa<CTYPE>(
ctx, output, q, k, v, is_causal, attn_mask, scale,
seq_dim,
start_pos, num_keys_for_causal_attention);
start_pos, num_keys_for_causal_attention,
q_zero_points, q_scales,
k_zero_points, k_scales,
v_zero_points, v_scales);
});
} else {
ET_SWITCH_FLOAT_TYPES(
Expand Down
Loading
Loading