From f46b5d25dd8f674968142c5660054a23b570d800 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 1 Apr 2026 13:32:26 -0700 Subject: [PATCH] Add quantized input support to cpu_sdpa cpu_sdpa (unfused SDPA) previously only supported float inputs. When the model uses quantized Q/K/V (int8 with per-channel scales and zero_points), decode fell back to cpu_flash_attention, missing the ~25-30% throughput improvement from unfused SDPA. This adds quantized support to cpu_sdpa by: - Accepting optional quantization params (zero_points, scales for Q/K/V) - Using _q_at_k_gemm for Q@K^T (handles both int8 and float) - Using _qk_at_v_gemm for scores@V (handles both int8 and float) - Applying scaling factor separately (fused with mask add or max reduction) - Allocating a dequantization buffer for V when quantized The dispatch in op_sdpa.cpp is updated to route quantized decode (seq_len==1) through cpu_sdpa instead of cpu_flash_attention. Differential Revision: [D96044310](https://our.internmc.facebook.com/intern/diff/D96044310/) [ghstack-poisoned] --- .../llm/custom_ops/op_custom_sdpa_test.cpp | 188 +++++++++++++++++- extension/llm/custom_ops/op_sdpa.cpp | 8 +- extension/llm/custom_ops/op_sdpa_impl.h | 175 ++++++++++++---- 3 files changed, 328 insertions(+), 43 deletions(-) diff --git a/extension/llm/custom_ops/op_custom_sdpa_test.cpp b/extension/llm/custom_ops/op_custom_sdpa_test.cpp index 619afae7466..e61edb99fd3 100644 --- a/extension/llm/custom_ops/op_custom_sdpa_test.cpp +++ b/extension/llm/custom_ops/op_custom_sdpa_test.cpp @@ -7,8 +7,9 @@ */ // Tests for the unfused SDPA code path (cpu_sdpa) dispatched when -// seq_len == 1 and inputs are non-quantized (the decode fast-path). -// These call custom_sdpa_out directly, not through sdpa_with_kv_cache. +// seq_len == 1 (the decode fast-path). Covers both float and quantized +// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out +// directly, not through sdpa_with_kv_cache. 
 #include
 #include

@@ -114,6 +115,55 @@ void compute_reference_sdpa(
   }
 }
 
+/**
+ * Dequantize int8 tensor in [B, S, H, D] layout using per-token
+ * scales/zero_points in [B, S, H, 1] layout.
+ * dequant(x) = (x - zero_point) * scale
+ */
+void dequantize_per_token(
+    const int8_t* data, int B, int S, int H, int D,
+    const float* scales,
+    const int8_t* zps,
+    float* out) {
+  for (int b = 0; b < B; b++) {
+    for (int s = 0; s < S; s++) {
+      for (int h = 0; h < H; h++) {
+        int param_idx = b * S * H + s * H + h;
+        float sc = scales[param_idx];
+        float zp = static_cast<float>(zps[param_idx]);
+        for (int d = 0; d < D; d++) {
+          int idx = b * S * H * D + s * H * D + h * D + d;
+          out[idx] = (static_cast<float>(data[idx]) - zp) * sc;
+        }
+      }
+    }
+  }
+}
+
+// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout.
+executorch::aten::Tensor call_custom_quantized_sdpa(
+    const executorch::aten::Tensor& q,
+    const executorch::aten::Tensor& k,
+    const executorch::aten::Tensor& v,
+    int64_t start_pos,
+    const std::optional<executorch::aten::Tensor>& attn_mask,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale,
+    const std::optional<executorch::aten::Tensor>& q_zp,
+    const std::optional<executorch::aten::Tensor>& q_sc,
+    const std::optional<executorch::aten::Tensor>& k_zp,
+    const std::optional<executorch::aten::Tensor>& k_sc,
+    const std::optional<executorch::aten::Tensor>& v_zp,
+    const std::optional<executorch::aten::Tensor>& v_sc,
+    executorch::aten::Tensor& out) {
+  executorch::runtime::KernelRuntimeContext ctx{};
+  return torch::executor::native::custom_quantized_sdpa_out(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale,
+      q_zp, q_sc, k_zp, k_sc, v_zp, v_sc,
+      /*is_seq_at_dim_1=*/false, out);
+}
+
 } // namespace
 
 // With a single KV entry (start_pos=0), output must equal V[0].
@@ -263,3 +313,137 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
 
   EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
 }
+
+// Quantized decode: int8 Q/K/V with per-token scales and zero_points,
+// verified against dequantize-then-float-SDPA reference.
+TEST(OpCustomSdpaTest, DecodeQuantized) {
+  TensorFactory<ScalarType::Char> tfChar;
+  TensorFactory<ScalarType::Float> tfFloat;
+
+  // Q: [B=1, S=1, H=2, D=4] as int8
+  auto q = tfChar.make(
+      {1, 1, 2, 4},
+      {10, 20, -5, 15, -10, 5, 25, -20});
+
+  // K: [B=1, kv_len=3, H=2, D=4] as int8
+  auto k = tfChar.make(
+      {1, 3, 2, 4},
+      {8, -12, 18, 5, -3, 22, -8, 14,
+       15, 7, -20, 10, 12, -15, 9, 6,
+       -5, 25, 3, -10, 20, 8, -12, 17});
+
+  // V: [B=1, kv_len=3, H=2, D=4] as int8
+  auto v = tfChar.make(
+      {1, 3, 2, 4},
+      {5, 15, -8, 20, 10, -5, 18, 12,
+       -12, 8, 22, -3, 7, 20, -10, 15,
+       18, -5, 10, 3, -8, 12, 5, -20});
+
+  // Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1]
+  auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f});
+  auto k_sc = tfFloat.make({1, 3, 2, 1},
+      {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto v_sc = tfFloat.make({1, 3, 2, 1},
+      {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0});
+  auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+  auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+
+  int64_t start_pos = 2;
+  int num_valid = 3;
+
+  // Dequantize and compute float reference
+  std::vector<float> q_deq(8), k_deq(24), v_deq(24);
+  dequantize_per_token(
+      q.const_data_ptr<int8_t>(), 1, 1, 2, 4,
+      q_sc.const_data_ptr<float>(), q_zp.const_data_ptr<int8_t>(),
+      q_deq.data());
+  dequantize_per_token(
+      k.const_data_ptr<int8_t>(), 1, 3, 2, 4,
+      k_sc.const_data_ptr<float>(), k_zp.const_data_ptr<int8_t>(),
+      k_deq.data());
+  dequantize_per_token(
+      v.const_data_ptr<int8_t>(), 1, 3, 2, 4,
+      v_sc.const_data_ptr<float>(), v_zp.const_data_ptr<int8_t>(),
+      v_deq.data());
+
+  std::vector<float> ref(8, 0.0f);
+  compute_reference_sdpa(
+      q_deq.data(), 1, 1, 2, 4,
+      k_deq.data(), 3, 2,
+      v_deq.data(),
+      ref.data(), false, start_pos, num_valid);
+
+  auto expected = tfFloat.make({1, 1, 2, 4}, ref);
+  auto out = tfFloat.zeros({1, 1, 2, 4});
+  call_custom_quantized_sdpa(
+      q, k, v, start_pos, {}, 0.0, false, {},
+      q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, out);
+  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3,
1e-3);
+}
+
+// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs.
+TEST(OpCustomSdpaTest, DecodeQuantizedGQA) {
+  TensorFactory<ScalarType::Char> tfChar;
+  TensorFactory<ScalarType::Float> tfFloat;
+
+  // Q: [B=1, S=1, H_q=4, D=4] as int8
+  auto q = tfChar.make(
+      {1, 1, 4, 4},
+      {10, 20, -5, 15, -10, 5, 25, -20,
+       8, -3, 12, 7, -15, 18, 4, -8});
+
+  // K: [B=1, kv_len=3, H_kv=2, D=4] as int8
+  auto k = tfChar.make(
+      {1, 3, 2, 4},
+      {8, -12, 18, 5, -3, 22, -8, 14,
+       15, 7, -20, 10, 12, -15, 9, 6,
+       -5, 25, 3, -10, 20, 8, -12, 17});
+
+  // V: [B=1, kv_len=3, H_kv=2, D=4] as int8
+  auto v = tfChar.make(
+      {1, 3, 2, 4},
+      {5, 15, -8, 20, 10, -5, 18, 12,
+       -12, 8, 22, -3, 7, 20, -10, 15,
+       18, -5, 10, 3, -8, 12, 5, -20});
+
+  auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f});
+  auto k_sc = tfFloat.make({1, 3, 2, 1},
+      {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto v_sc = tfFloat.make({1, 3, 2, 1},
+      {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0});
+  auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+  auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+
+  int64_t start_pos = 2;
+  int num_valid = 3;
+
+  std::vector<float> q_deq(16), k_deq(24), v_deq(24);
+  dequantize_per_token(
+      q.const_data_ptr<int8_t>(), 1, 1, 4, 4,
+      q_sc.const_data_ptr<float>(), q_zp.const_data_ptr<int8_t>(),
+      q_deq.data());
+  dequantize_per_token(
+      k.const_data_ptr<int8_t>(), 1, 3, 2, 4,
+      k_sc.const_data_ptr<float>(), k_zp.const_data_ptr<int8_t>(),
+      k_deq.data());
+  dequantize_per_token(
+      v.const_data_ptr<int8_t>(), 1, 3, 2, 4,
+      v_sc.const_data_ptr<float>(), v_zp.const_data_ptr<int8_t>(),
+      v_deq.data());
+
+  std::vector<float> ref(16, 0.0f);
+  compute_reference_sdpa(
+      q_deq.data(), 1, 1, 4, 4,
+      k_deq.data(), 3, 2,
+      v_deq.data(),
+      ref.data(), false, start_pos, num_valid);
+
+  auto expected = tfFloat.make({1, 1, 4, 4}, ref);
+  auto out = tfFloat.zeros({1, 1, 4, 4});
+  call_custom_quantized_sdpa(
+      q, k, v, start_pos, {}, 0.0, false, {},
+      q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, out);
+  
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3); +} diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 60144b92019..dae25cf702d 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -412,15 +412,17 @@ Tensor& custom_sdpa_out_impl( InvalidArgument, output); - bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char && - seq_len == 1; + bool use_unfused_sdpa = seq_len == 1; if (use_unfused_sdpa) { ET_SWITCH_FLOAT_TYPES( output.scalar_type(), ctx, "sdpa", CTYPE, [&] { sdpa::impl::cpu_sdpa( ctx, output, q, k, v, is_causal, attn_mask, scale, seq_dim, - start_pos, num_keys_for_causal_attention); + start_pos, num_keys_for_causal_attention, + q_zero_points, q_scales, + k_zero_points, k_scales, + v_zero_points, v_scales); }); } else { ET_SWITCH_FLOAT_TYPES( diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index a54958faad7..2c3a4126311 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -1123,7 +1123,13 @@ void cpu_sdpa( const optional& scale, const SeqDim seq_dim, const int64_t start_pos, - const int64_t num_keys_for_causal_attention) { + const int64_t num_keys_for_causal_attention, + const optional& q_zero_points = nullopt, + const optional& q_scales = nullopt, + const optional& k_zero_points = nullopt, + const optional& k_scales = nullopt, + const optional& v_zero_points = nullopt, + const optional& v_scales = nullopt) { using accum_t = scalar_t; using Vec = vec::Vectorized; accum_t scaling_factor = static_cast(calculate_scale(query, scale)); @@ -1158,6 +1164,7 @@ void cpu_sdpa( int64_t num_reps = num_head / num_heads_kv; bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); + bool is_quantized_sdpa = query.scalar_type() == ScalarType::Char; // Extract strides, swapping seq/head dims based on seq_dim auto q_strides = query.strides(); @@ -1186,6 +1193,33 @@ void 
cpu_sdpa( mStrideM = m_strides[0]; } + int64_t q_quant_params_StrideB = 0; + int64_t q_quant_params_StrideH = 0; + int64_t q_quant_params_StrideM = 0; + int64_t k_quant_params_StrideB = 0; + int64_t k_quant_params_StrideH = 0; + int64_t k_quant_params_StrideN = 0; + int64_t v_quant_params_StrideB = 0; + int64_t v_quant_params_StrideH = 0; + int64_t v_quant_params_StrideN = 0; + + if (is_quantized_sdpa) { + auto q_qp_strides = q_zero_points.value().strides(); + q_quant_params_StrideB = q_qp_strides[0]; + q_quant_params_StrideH = (seq_dim == SeqDim::ONE) ? q_qp_strides[2] : q_qp_strides[1]; + q_quant_params_StrideM = (seq_dim == SeqDim::ONE) ? q_qp_strides[1] : q_qp_strides[2]; + + auto k_qp_strides = k_zero_points.value().strides(); + k_quant_params_StrideB = k_qp_strides[0]; + k_quant_params_StrideH = (seq_dim == SeqDim::ONE) ? k_qp_strides[2] : k_qp_strides[1]; + k_quant_params_StrideN = (seq_dim == SeqDim::ONE) ? k_qp_strides[1] : k_qp_strides[2]; + + auto v_qp_strides = v_zero_points.value().strides(); + v_quant_params_StrideB = v_qp_strides[0]; + v_quant_params_StrideH = (seq_dim == SeqDim::ONE) ? v_qp_strides[2] : v_qp_strides[1]; + v_quant_params_StrideN = (seq_dim == SeqDim::ONE) ? 
v_qp_strides[1] : v_qp_strides[2]; + } + // Allocate per-thread scores buffer: [qSize, kvSize] per (batch, head) #ifdef ET_USE_THREADPOOL int64_t num_thread = @@ -1207,6 +1241,24 @@ void cpu_sdpa( } accum_t* buf_data = reinterpret_cast(buf); + // Allocate dequantization buffer for V (used by _qk_at_v_gemm when m > 4) + int64_t size_per_thread_qdq_vec = kvSize * headSize; + std::unique_ptr allocated_buf_for_qdq; + accum_t* scratch_for_quant_dequant = nullptr; + if (is_quantized_sdpa) { + int64_t size_qdq_bytes = + size_per_thread_qdq_vec * num_thread * sizeof(accum_t); + Result scratch_qdq = ctx.allocate_temp(size_qdq_bytes, 64); + if (!scratch_qdq.ok()) { + allocated_buf_for_qdq = std::make_unique(size_qdq_bytes); + scratch_for_quant_dequant = + reinterpret_cast(allocated_buf_for_qdq.get()); + } else { + scratch_for_quant_dequant = + reinterpret_cast(scratch_qdq.get()); + } + } + const scalar_t* q_data = query.const_data_ptr(); const scalar_t* k_data = key.const_data_ptr(); const scalar_t* v_data = value.const_data_ptr(); @@ -1217,41 +1269,77 @@ void cpu_sdpa( auto compute_lambda = [&](int64_t begin, int64_t end) { int64_t ompIdx = torch::executor::get_thread_num(); accum_t* scores = buf_data + ompIdx * scores_per_thread; + accum_t* buf_qdq_ptr = is_quantized_sdpa + ? 
scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec + : nullptr; for (int64_t idx = begin; idx < end; ++idx) { int64_t b = idx / num_head; int64_t h = idx % num_head; int64_t kv_h = h / num_reps; - // Pointer to Q[b, h, :, :] and K[b, kv_h, :, :] with appropriate strides - const scalar_t* q_ptr = q_data + b * qStrideB + h * qStrideH; - const scalar_t* k_ptr = k_data + b * kStrideB + kv_h * kStrideH; - const scalar_t* v_ptr = v_data + b * vStrideB + kv_h * vStrideH; + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const float* q_scales_ptr = nullptr; + const float* k_scales_ptr = nullptr; + const float* v_scales_ptr = nullptr; + const int8_t* q_zp_ptr = nullptr; + const int8_t* k_zp_ptr = nullptr; + const int8_t* v_zp_ptr = nullptr; + + int64_t q_offset = b * qStrideB + h * qStrideH; + int64_t k_offset = b * kStrideB + kv_h * kStrideH; + int64_t v_offset = b * vStrideB + kv_h * vStrideH; + + if (is_quantized_sdpa) { + q_ptr = reinterpret_cast(q_data) + q_offset; + k_ptr = reinterpret_cast(k_data) + k_offset; + v_ptr = reinterpret_cast(v_data) + v_offset; + + int64_t q_qp_offset = + b * q_quant_params_StrideB + h * q_quant_params_StrideH; + int64_t k_qp_offset = + b * k_quant_params_StrideB + kv_h * k_quant_params_StrideH; + int64_t v_qp_offset = + b * v_quant_params_StrideB + kv_h * v_quant_params_StrideH; + + q_scales_ptr = + q_scales.value().const_data_ptr() + q_qp_offset; + k_scales_ptr = + k_scales.value().const_data_ptr() + k_qp_offset; + v_scales_ptr = + v_scales.value().const_data_ptr() + v_qp_offset; + q_zp_ptr = + q_zero_points.value().const_data_ptr() + q_qp_offset; + k_zp_ptr = + k_zero_points.value().const_data_ptr() + k_qp_offset; + v_zp_ptr = + v_zero_points.value().const_data_ptr() + v_qp_offset; + } else { + q_ptr = q_data + q_offset; + k_ptr = k_data + k_offset; + v_ptr = v_data + v_offset; + } scalar_t* o_ptr = out_data + b * oStrideB + h * oStrideH; - // GEMM 1: scores[qSize, kvSize] = scaling_factor * Q[qSize, D] @ K^T[D, 
kvSize] - ::executorch::cpublas::gemm( - ::executorch::cpublas::TransposeType::Transpose, - ::executorch::cpublas::TransposeType::NoTranspose, - kvSize, qSize, headSize, - scaling_factor, - k_ptr, kStrideN, - q_ptr, qStrideM, - static_cast(0), - scores, kvSize); - - // Causal mask + attention mask + softmax per query row + // GEMM 1: scores[qSize, kvSize] = Q[qSize, D] @ K^T[D, kvSize] + MaybeQuantizedMatrixData q_matrix( + q_ptr, q_zp_ptr, q_scales_ptr, + qSize, headSize, q_quant_params_StrideM, query.scalar_type()); + MaybeQuantizedMatrixData k_matrix( + k_ptr, k_zp_ptr, k_scales_ptr, + kvSize, headSize, k_quant_params_StrideN, key.scalar_type()); + _q_at_k_gemm( + qSize, kvSize, headSize, + q_matrix, qStrideM, + k_matrix, kStrideN, + scores); + + // Causal mask + scaling + attention mask + softmax per query row for (int64_t qi = 0; qi < qSize; ++qi) { accum_t* row = scores + qi * kvSize; - // Apply attention mask if present - if (has_attn_mask) { - const accum_t* mask_row = mask_data + qi * mStrideM; - for (int64_t j = 0; j < kvSize; ++j) { - row[j] += mask_row[j]; - } - } - // Apply causal mask if (is_causal) { int64_t valid = std::min(start_pos + qi + 1, kvSize); @@ -1260,16 +1348,27 @@ void cpu_sdpa( } } - // Softmax: find max, compute exp, normalize - accum_t max_val = vec::reduce_all( - [](Vec& x, Vec& y) { return vec::maximum(x, y); }, - row, kvSize); + accum_t max_val; + const int kvSizeInt = static_cast(kvSize); + if (has_attn_mask) { + // Apply scaling factor and attention mask in fusion + const accum_t* mask_row = mask_data + qi * mStrideM; + for (int64_t j = 0; j < kvSize; ++j) { + row[j] = row[j] * scaling_factor + mask_row[j]; + } + max_val = vec::reduce_all( + [](Vec& x, Vec& y) { return vec::maximum(x, y); }, + row, kvSize); + } else { + // Apply scaling factor and find max in fusion + _mul_reduce_max_fusion_kernel( + row, scaling_factor, kvSizeInt, row, max_val); + } if (max_val == -std::numeric_limits::infinity()) { fill_stub(row, 
static_cast(0), kvSize); } else { accum_t sum_val = max_val; - const int kvSizeInt = static_cast(kvSize); _exp_reduce_sum_fusion_kernel(row, kvSizeInt, row, sum_val); accum_t inv_sum = static_cast(1) / sum_val; vec::map( @@ -1279,15 +1378,15 @@ void cpu_sdpa( } // GEMM 2: output[qSize, D] = scores[qSize, kvSize] @ V[kvSize, D] - ::executorch::cpublas::gemm( - ::executorch::cpublas::TransposeType::NoTranspose, - ::executorch::cpublas::TransposeType::NoTranspose, - headSize, qSize, kvSize, - static_cast(1), - v_ptr, vStrideN, + MaybeQuantizedMatrixData v_matrix( + v_ptr, v_zp_ptr, v_scales_ptr, + kvSize, headSize, v_quant_params_StrideN, value.scalar_type()); + _qk_at_v_gemm( + qSize, headSize, kvSize, scores, kvSize, - static_cast(0), - o_ptr, oStrideM); + v_matrix, vStrideN, + o_ptr, oStrideM, + static_cast(0), buf_qdq_ptr); } }; torch::executor::parallel_for(