[hipDNN] Add engine-agnostic SDPA forward integration test (ALMIOPEN-2127)#8824
Open
DarylHawkinsAMD wants to merge 7 commits into
Open
[hipDNN] Add engine-agnostic SDPA forward integration test (ALMIOPEN-2127)#8824DarylHawkinsAMD wants to merge 7 commits into
DarylHawkinsAMD wants to merge 7 commits into
Conversation
…ion test (ALMIOPEN-2127)
…cases Narrow the engine-agnostic SDPA forward suite to the real shipping surface (bf16 -> AITER ASM; bottom-right causal, the variant gfx942 ships) and drop the fp16 suite and top-left causal case, which had no engine and only SKIPped. Add MQA, non-square bottom-right causal, and an off-tile remainder seqlen.
…>kv, large) Cover distinct codepaths the base matrix misses: single partial tile (64), causal on an off-tile seqlen (mask+remainder interaction), bottom-right causal with seqQ>seqKv (opposite anchoring boundary), and a 2048 multi-tile seqlen for accumulation numerics. All 10 cases PASS on gfx942 (DUT=AITER ASM, ref=CPU).
The large multi-tile case spent ~32s in the O(n^2) CPU reference, dominating suite wall time. Revisit a large-seqlen case once the GPU reference (#8438) makes the reference cheap.
- Remove STATUS.md (ephemeral PR-scratch doc); durable rationale (engine- intersection scope, D-remainder out of scope, CI enablement deferred to #8438) moved into the test file as comments. - Add mha_peaked_softmax (large attn scale) so the streaming cross-KV-tile running-max correction is actually exercised; default [-1,1] inputs give a near-flat softmax that hid that path. - Add mha_bshd_layout (non-packed BSHD strides) to exercise the kernel stride/addressing path, not just contiguous BHSD. AITER ASM accepts it (PASS). - Fix the inverted causal_br_q_gt_kv comment: first seqQ-seqKv query rows are fully masked; case validates the empty-row (0, not NaN) convention. - Add a gtest name generator so cases surface by note (e.g. .../mha_peaked_softmax) instead of /N. - Annotate the unused stats binding; reword the headsKv field comment; trim unused includes (PlatformUtils, CpuFpReferenceValidation) to the headers used. All 11 cases PASS on gfx942 (DUT=AITER ASM, ref=CPU).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds an engine-agnostic SDPA forward-pass integration test built on the shared
IntegrationGraphVerificationHarness: it constructs a frontendGraph::sdpa()graph and validates the device engine's output against the reference graph executor (CPU today; GPU once #8438 lands). This establishes the engine-neutral / GPU-reference vehicle for SDPA forward correctness, complementing the existing provider-local AITER ASM forward test.Risk Assessment
Low risk. Test-only, non-shipping change: a new translation unit gated by
HIPDNN_ENABLE_SDPAplus a singleSdpaFwdNodetolerance branch in the integration-test harness. It adds no product behavior, public API, schema, or build-default change. Residual risk is coverage breadth, not correctness: the suite is not yet bound to a CI lane (enablement rides with #8438), and only gfx942 has been validated so far.ASIC Coverage
No multi-arch sweep is meaningful for this change. The cases execute only where an SDPA forward engine exists (AITER ASM on gfx942/gfx950); on any other target the graph has no supporting engine and the cases
GTEST_SKIP. The suite is also not yet wired to an SDPA-capable ctest target, so PR CI compiles the TU and the cases SKIP in the only lane that runs the sharedhipdnn_integration_testsexe (miopen-provider) — passing PR CI does not exercise them. Engine-relevant verification is therefore gfx942 (done) and gfx950 (pending); end-to-end CI enablement across those arches lands with #8438.Testing Summary
sdpa::getToleranceFwd<bf16>(1e-2).Testing Checklist
cmake --build build --target hipdnn_integration_tests- Status: Passed./build/bin/hipdnn_integration_tests --reference-executor cpu --gtest_filter='*IntegrationGpuSdpaFwd*'(withHIPDNN_AITER_ASM_DIRset to the in-tree AITER ASM kernels) - ASICs: gfx942 (MI300A, ROCm 7.14) - Status: Passedgit commit- Status: PassedTechnical Changes
dnn-providers/integration-tests/src/integration_tests/sdpa/IntegrationGpuSdpaFwdInference.cpp: anSdpaForward<DataType>harness subclass that builds Q/K/V (BHSD, with one BSHD-layout case), drivesGraph::sdpa()with the modern non-deprecated causal form (set_diagonal_band_right_bound+DiagonalAlignment::BOTTOM_RIGHT) andattn_scale = 1/sqrt(d), and validates the output tensor against the reference executor. A gtest name generator surfaces cases by name.SdpaFwdNode -> sdpa::getToleranceFwd<T>()tolerance branch (and the matching include), reusing the existing tolerance-dispatch extension point.CMakeLists.txt.nhead_qa multiple of 8); rationale and the deferral of CI enablement to #8438 are documented in-file. The whole TU is gated byHIPDNN_ENABLE_SDPA.