From 804ebb5bb97ee9a9460037879b7993a3be7adbe3 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 6 Apr 2026 08:47:23 -0700
Subject: [PATCH] Add contig requirement to aot binding for llm custom ops

Differential Revision: [D93870393](https://our.internmc.facebook.com/intern/diff/D93870393/)

[ghstack-poisoned]
---
 extension/llm/custom_ops/op_sdpa_aot.cpp | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index e50b3707d51..8ec0ab40a65 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -183,11 +183,11 @@ Tensor& sdpa_with_kv_cache_out_no_context(
 }
 
 at::Tensor sdpa_with_kv_cache_aten(
-    const at::Tensor& q_projected,
-    const at::Tensor& k_projected,
-    const at::Tensor& v_projected,
-    at::Tensor& key_cache,
-    at::Tensor& value_cache,
+    const at::Tensor& q_proj,
+    const at::Tensor& k_proj,
+    const at::Tensor& v_proj,
+    at::Tensor& k_cache,
+    at::Tensor& v_cache,
     const int64_t start_pos,
     const int64_t seq_len,
     // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
@@ -197,6 +197,11 @@ at::Tensor sdpa_with_kv_cache_aten(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale) {
+  auto q_projected = q_proj.contiguous();
+  auto k_projected = k_proj.contiguous();
+  auto v_projected = v_proj.contiguous();
+  auto key_cache = k_cache.contiguous();
+  auto value_cache = v_cache.contiguous();
   auto output = at::empty_like(q_projected);
   WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
   (q_projected,
@@ -256,11 +261,14 @@ at::Tensor custom_sdpa_aten(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale,
     const bool is_seq_dim_2) {
-  auto output = at::empty(q.sizes());
+  auto q_projected = q.contiguous();
+  auto k_projected = k.contiguous();
+  auto v_projected = v.contiguous();
+  auto output = at::empty_like(q_projected);
   WRAP_TO_ATEN(custom_sdpa_out_no_context, 9)
-  (q,
-   k,
-   v,
+  (q_projected,
+   k_projected,
+   v_projected,
    start_pos,
    attn_mask,
    dropout_p,
@@ -331,7 +339,10 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& v_zero_points,
     const std::optional<at::Tensor>& v_scales,
     const bool is_seq_at_dim_2) {
-  auto output = at::empty(q.sizes());
+  auto q_projected = q.contiguous();
+  auto k_projected = k.contiguous();
+  auto v_projected = v.contiguous();
+  auto output = at::empty(q_projected.sizes());
   WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
   (q,
    k,
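
Note (illustration, not part of the patch): per the subject line, the AOT bindings now require contiguous inputs, presumably because the wrapped out-variant kernels read raw tensor memory assuming a packed, row-major layout; a strided view such as a transposed q would otherwise be read incorrectly, so the bindings materialize contiguous copies first. Below is a minimal standalone libtorch sketch of that failure mode; the file name and build command are illustrative assumptions, not taken from this diff.

// contig_demo.cpp -- sketch only; build against libtorch, e.g.:
//   g++ -std=c++17 contig_demo.cpp -ltorch -ltorch_cpu -lc10
#include <torch/torch.h>
#include <iostream>

int main() {
  // A tensor as produced upstream may be a strided view, e.g. after a transpose.
  at::Tensor q = at::arange(24, at::kFloat).reshape({2, 3, 4});
  at::Tensor q_t = q.transpose(1, 2); // shape [2, 4, 3], non-contiguous view

  std::cout << "is_contiguous: " << q_t.is_contiguous() << "\n"; // prints 0

  // Indexing data_ptr() with row-major arithmetic is wrong for the view:
  // raw[1] is the next element of the underlying storage, not q_t[0][0][1].
  const float* raw = q_t.data_ptr<float>();
  std::cout << "raw[1]        = " << raw[1] << "\n";                  // 1
  std::cout << "q_t[0][0][1]  = " << q_t[0][0][1].item<float>() << "\n"; // 4

  // After materializing a packed copy (what the patched bindings do before
  // calling the WRAP_TO_ATEN'd out-variants), row-major indexing is correct.
  at::Tensor q_c = q_t.contiguous();
  const float* packed = q_c.data_ptr<float>();
  std::cout << "packed[1]     = " << packed[1] << "\n";               // 4
  return 0;
}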