Skip to content
Open
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 345 additions & 5 deletions extension/llm/custom_ops/bench_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,145 @@ void run_standard_sdpa(
});
}

// ONNX Runtime GQA-style SDPA, faithfully ported from
// onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h.
// Differences from run_standard_sdpa:
// 1. Scale in GEMM alpha (no separate scaling pass)
// 2. Scores buffer padded to max_seq_len cols (ONNX's present_buffer_seq_len)
// 3. Causal mask: zero out future positions, softmax on valid window only
// 4. Output in [B, S, Hq, D] with stride Hq*D (ONNX's interleaved BNSH->BSNH)
//
// When is_transposed=true, inputs are [B,H,S,D]; output is [B,S,Hq,D].
// When is_transposed=false, inputs are [B,S,H,D]; output is [B,S,Hq,D].
// Output is always [B, S, Hq, D] to match ONNX's actual output format.
void run_onnx_gqa_sdpa(
    const float* q_data,
    const float* k_data,
    const float* v_data,
    float* out_data, // always [B, q_seq_len, Hq, D]
    float* scores_buf, // must hold batch*Hq*q_seq_len*max_seq_len floats
    int64_t batch, // B: batch size
    int64_t Hq, // number of query heads
    int64_t Hkv, // number of key/value heads; Hq is assumed divisible by Hkv
    int64_t D, // head dimension
    int64_t max_seq_len, // KV-cache capacity; also the scores row stride
    int64_t start_pos, // tokens already in the cache before this call
    int64_t q_seq_len, // number of new query tokens
    bool is_transposed) { // true: inputs [B,H,S,D]; false: inputs [B,S,H,D]
  using executorch::cpublas::TransposeType;

  // Attention window covers cached tokens plus the new queries.
  const int64_t total_seqlen = start_pos + q_seq_len;
  // 1/sqrt(D) scale is folded into GEMM 1's alpha (difference #1 above).
  const float alpha = 1.0f / std::sqrt(static_cast<float>(D));
  // GQA: each group of heads_per_group query heads shares one KV head.
  const int64_t heads_per_group = Hq / Hkv;
  const int64_t hidden_size = Hq * D; // output row stride (ONNX convention)

  // Input strides depend on layout
  const int64_t ldq = is_transposed ? D : Hq * D;
  const int64_t ldk = is_transposed ? D : Hkv * D;
  const int64_t ldv = is_transposed ? D : Hkv * D;
  // Output is always [B, S, Hq, D] so ldo = Hq * D = hidden_size
  const int64_t ldo = hidden_size;

  // One work unit per (batch, query-head) pair.
  torch::executor::parallel_for(
      0, batch * Hq, 1, [&](int64_t begin, int64_t end) {
        for (int64_t idx = begin; idx < end; ++idx) {
          const int64_t b = idx / Hq;
          const int64_t h = idx % Hq;
          // KV head shared by this query head's group.
          const int64_t kv_h = h / heads_per_group;

          const float* q_ptr;
          const float* k_ptr;
          const float* v_ptr;
          if (is_transposed) {
            // [B,H,S,D]: each head's rows are contiguous.
            q_ptr = q_data + (b * Hq + h) * q_seq_len * D;
            k_ptr = k_data + (b * Hkv + kv_h) * max_seq_len * D;
            v_ptr = v_data + (b * Hkv + kv_h) * max_seq_len * D;
          } else {
            // [B,S,H,D]: heads interleave; rows stride by H*D (ldq/ldk/ldv).
            q_ptr = q_data + b * q_seq_len * Hq * D + h * D;
            k_ptr = k_data + b * max_seq_len * Hkv * D + kv_h * D;
            v_ptr = v_data + b * max_seq_len * Hkv * D + kv_h * D;
          }
          // Output always [B, S, Hq, D]: head h writes at stride hidden_size
          float* out_ptr = out_data + b * q_seq_len * hidden_size + h * D;

          // Scores padded to max_seq_len columns (ONNX convention)
          float* scores = scores_buf + idx * q_seq_len * max_seq_len;

          // GEMM 1: Q @ K^T with scale in alpha
          // NOTE(review): argument order (m=total_seqlen, n=q_seq_len,
          // ldc=max_seq_len) suggests cpublas::gemm follows column-major
          // BLAS conventions, making scores row-major [q_seq_len x
          // total_seqlen] with row stride max_seq_len — confirm against
          // cpublas docs.
          executorch::cpublas::gemm(
              TransposeType::Transpose,
              TransposeType::NoTranspose,
              total_seqlen,
              q_seq_len,
              D,
              alpha,
              k_ptr,
              ldk,
              q_ptr,
              ldq,
              0.0f,
              scores,
              max_seq_len);

          // Causal mask + narrow softmax (ONNX style):
          // Zero future positions, softmax only on valid [0, causal_len).
          for (int64_t qi = 0; qi < q_seq_len; ++qi) {
            float* row = scores + qi * max_seq_len;
            // Query qi may attend to cached tokens plus itself.
            const int64_t causal_len =
                std::min(start_pos + qi + 1, total_seqlen);

            // Zero masked (future) positions so GEMM 2, which reduces over
            // total_seqlen, sees no contribution from them.
            for (int64_t j = causal_len; j < total_seqlen; ++j) {
              row[j] = 0.0f;
            }

            // Numerically-stable softmax over the valid window only.
            float max_val = row[0];
            for (int64_t j = 1; j < causal_len; ++j) {
              max_val = std::max(max_val, row[j]);
            }
            float sum = 0.0f;
            for (int64_t j = 0; j < causal_len; ++j) {
              row[j] = std::exp(row[j] - max_val);
              sum += row[j];
            }
            // sum >= exp(0) = 1 for the max element, so division is safe.
            const float inv_sum = 1.0f / sum;
            for (int64_t j = 0; j < causal_len; ++j) {
              row[j] *= inv_sum;
            }
          }

          // GEMM 2: scores @ V -> output
          executorch::cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::NoTranspose,
              D,
              q_seq_len,
              total_seqlen,
              1.0f,
              v_ptr,
              ldv,
              scores,
              max_seq_len,
              0.0f,
              out_ptr,
              ldo);
        }
      });
}

// Return max |a - b| over the first n elements of two float buffers.
// (The span previously interleaved the removed Tensor-based body with the
// new pointer-based one — stale diff residue; this is the clean merge.)
float max_abs_diff(const float* a, const float* b, int64_t n) {
  float d = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    d = std::max(d, std::abs(a[i] - b[i]));
  }
  return d;
}

// Convenience overload: element-wise max |a - b| of two float Tensors.
// Assumes a and b have the same numel — TODO confirm callers guarantee this.
float max_abs_diff(const Tensor& a, const Tensor& b) {
  const float* lhs = a.const_data_ptr<float>();
  const float* rhs = b.const_data_ptr<float>();
  return max_abs_diff(lhs, rhs, a.numel());
}

// Validate a single config: run StandardSDPA and custom_sdpa_out on the same
// inputs, check outputs match within tolerance. Returns false on mismatch.
// Only tests standard [B,S,H,D] layout (is_transposed=false).
Expand Down Expand Up @@ -268,6 +396,65 @@ bool validate_config(
(long)q_seq_len,
diff);

// Also validate ONNX GQA variant. Output is always [B, S, Hq, D].
// Since we only test standard [B,S,H,D] layout, out_ref is already
// [B,S,Hq,D] — just copy directly to ref_bshd (no transpose needed).
Tensor out_onnx =
tf.zeros({(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
std::vector<float> onnx_scores_buf(batch * Hq * q_seq_len * max_seq_len);
run_onnx_gqa_sdpa(
q.const_data_ptr<float>(),
k.const_data_ptr<float>(),
v.const_data_ptr<float>(),
out_onnx.mutable_data_ptr<float>(),
onnx_scores_buf.data(),
batch,
Hq,
Hkv,
D,
max_seq_len,
start_pos,
q_seq_len,
false /* is_transposed */);

// out_ref is already [B, S, Hq, D] (standard layout), compare directly
std::vector<float> ref_bshd(batch * q_seq_len * Hq * D);
const float* ref_ptr = out_ref.const_data_ptr<float>();
std::copy(ref_ptr, ref_ptr + batch * q_seq_len * Hq * D, ref_bshd.data());

float onnx_diff = max_abs_diff(
out_onnx.const_data_ptr<float>(),
ref_bshd.data(),
batch * q_seq_len * Hq * D);
if (onnx_diff > atol) {
fprintf(
stderr,
"FAIL: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
"max_abs_diff=%.6e > atol=%.6e\n",
mode,
(long)batch,
(long)Hq,
(long)Hkv,
(long)D,
(long)start_pos,
(long)q_seq_len,
onnx_diff,
atol);
return false;
}
fprintf(
stderr,
"PASS: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
"max_abs_diff=%.6e\n",
mode,
(long)batch,
(long)Hq,
(long)Hkv,
(long)D,
(long)start_pos,
(long)q_seq_len,
onnx_diff);

return true;
}

Expand Down Expand Up @@ -517,6 +704,132 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
}
}

// Benchmark fixture for the ONNX Runtime GQA-style SDPA. Mirrors the
// algorithm from gqa_attention_base.h: scale folded into GEMM alpha, scores
// buffer padded to max_seq_len columns, softmax over the valid causal window
// only, and output laid out [B, S, Hq, D] with row stride Hq*D.
class OnnxGQABenchFixture : public benchmark::Fixture {
 public:
  // Benchmark args, in order: {batch, num_heads_q, num_heads_kv, head_dim,
  // max_seq_len, start_pos, query_seq_len, is_transposed}
  void SetUp(benchmark::State& state) override {
    batch_ = state.range(0);
    num_heads_q_ = state.range(1);
    num_heads_kv_ = state.range(2);
    head_dim_ = state.range(3);
    max_seq_len_ = state.range(4);
    start_pos_ = state.range(5);
    q_seq_len_ = state.range(6);
    is_transposed_ = state.range(7) != 0;

    // Fixed seed so every run sees identical input data.
    std::mt19937 gen(42);

    const int32_t b = (int32_t)batch_;
    const int32_t hq = (int32_t)num_heads_q_;
    const int32_t hkv = (int32_t)num_heads_kv_;
    const int32_t d = (int32_t)head_dim_;
    const int32_t ms = (int32_t)max_seq_len_;
    const int32_t sl = (int32_t)q_seq_len_;

    if (is_transposed_) {
      // [B, H, S, D] layout for Q and the KV caches.
      q_.emplace(tf_.zeros({b, hq, sl, d}));
      k_cache_.emplace(tf_.zeros({b, hkv, ms, d}));
      v_cache_.emplace(tf_.zeros({b, hkv, ms, d}));
    } else {
      // [B, S, H, D] layout for Q and the KV caches.
      q_.emplace(tf_.zeros({b, sl, hq, d}));
      k_cache_.emplace(tf_.zeros({b, ms, hkv, d}));
      v_cache_.emplace(tf_.zeros({b, ms, hkv, d}));
    }
    // Output always [B, S, Hq, D] (ONNX convention).
    output_.emplace(tf_.zeros({b, sl, hq, d}));

    fill_random(*q_, gen);
    fill_random(*k_cache_, gen);
    fill_random(*v_cache_, gen);

    // Scores scratch: one padded [q_seq_len x max_seq_len] slab per
    // (batch, query-head) work unit (ONNX convention).
    scores_buf_.resize(batch_ * num_heads_q_ * q_seq_len_ * max_seq_len_);
  }

  void TearDown(benchmark::State&) override {
    // Release tensors and scratch between benchmark cases.
    q_.reset();
    k_cache_.reset();
    v_cache_.reset();
    output_.reset();
    scores_buf_.clear();
  }

  TensorFactory<ScalarType::Float> tf_;
  std::optional<Tensor> q_;
  std::optional<Tensor> k_cache_;
  std::optional<Tensor> v_cache_;
  std::optional<Tensor> output_;
  std::vector<float> scores_buf_;
  int64_t batch_ = 0;
  int64_t num_heads_q_ = 0;
  int64_t num_heads_kv_ = 0;
  int64_t head_dim_ = 0;
  int64_t max_seq_len_ = 0;
  int64_t start_pos_ = 0;
  int64_t q_seq_len_ = 0;
  bool is_transposed_ = false;
};

BENCHMARK_DEFINE_F(OnnxGQABenchFixture, OnnxGQA)
(benchmark::State& state) {
  // Resolve data pointers once, outside the timed loop.
  const float* const q_ptr = q_->const_data_ptr<float>();
  const float* const k_ptr = k_cache_->const_data_ptr<float>();
  const float* const v_ptr = v_cache_->const_data_ptr<float>();
  float* const o_ptr = output_->mutable_data_ptr<float>();

  // Timed region: one full ONNX-style GQA SDPA pass per iteration.
  for (auto _ : state) {
    run_onnx_gqa_sdpa(
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        scores_buf_.data(),
        batch_,
        num_heads_q_,
        num_heads_kv_,
        head_dim_,
        max_seq_len_,
        start_pos_,
        q_seq_len_,
        is_transposed_);
  }
}

/*
* Benchmark configurations modeled after Llama 3 8B (GQA: 32 q heads, 8 kv
* heads, head_dim=128). We test decode (seq_len=1) and prefill scenarios at
Expand Down Expand Up @@ -565,6 +878,33 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
->Args({1, 32, 32, 128, 2048, 256, 1, 1})
->ArgNames({"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});

// --- ONNX Runtime GQA-style SDPA ---
// Same configs as StandardSDPA. Differences: scale-in-alpha, padded scores
// buffer (ld=MaxS), narrow softmax, output in [B,S,Hq,D] with stride Hq*D.
// Arg order: {B, Hq, Hkv, D, MaxS, StartPos, SeqLen, Trans} — Trans=1 means
// inputs are [B,H,S,D] instead of [B,S,H,D].
BENCHMARK_REGISTER_F(OnnxGQABenchFixture, OnnxGQA)
    // Standard layout decode at various cache positions
    ->Args({1, 32, 8, 128, 2048, 0, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 64, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 256, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 512, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 1024, 1, 0})
    // Transposed layout decode at same positions
    ->Args({1, 32, 8, 128, 2048, 0, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 64, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 256, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 512, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 1024, 1, 1})
    // Standard layout prefill
    ->Args({1, 32, 8, 128, 2048, 0, 128, 0})
    ->Args({1, 32, 8, 128, 2048, 0, 512, 0})
    // Transposed layout prefill
    ->Args({1, 32, 8, 128, 2048, 0, 128, 1})
    ->Args({1, 32, 8, 128, 2048, 0, 512, 1})
    // Llama 2 style (32 heads, no GQA)
    ->Args({1, 32, 32, 128, 2048, 256, 1, 0})
    ->Args({1, 32, 32, 128, 2048, 256, 1, 1})
    ->ArgNames({"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});

} // namespace

int main(int argc, char** argv) {
Expand Down
Loading