Add ONNX Runtime GQA-style SDPA benchmark#18647
Add ONNX Runtime GQA-style SDPA benchmark#18647kimishpatel wants to merge 17 commits intogh/kimishpatel/220/basefrom
Conversation
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18647
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 Cancelled Job, 5 Unrelated FailuresAs of commit fa4dd36 with merge base 1debeb6 ( CANCELLED JOB - The following job was cancelled. Please retry:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
digantdesai
left a comment
There was a problem hiding this comment.
Review automatically exported from Phabricator review in Meta.
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:
Extends validation to verify ONNX GQA output matches custom_sdpa_out
reference. Adds OnnxGQABenchFixture for benchmarking both layouts.
Differential Revision: D96044317