Thanks Carl! The overall direction looks right (storing k_norm params as fp32 so the fused indexer in ROCm/aiter#3451 gets fp32 weights), but the conversion block in LayerNorm.forward is currently a no-op, so a few things should be fixed before merge:
1. The converted tensors are never used. Both layernorm2d_fwd_ and layernorm2d_fwd_with_add_ are still called with self.weight/self.bias rather than the local weight/bias produced just above, so the new if branch has no effect.
2. self.dtype != weight compares a torch.dtype against a Tensor, so it isn't a dtype check; it should be weight.dtype.
3. Even after fixing (1)/(2), the conversion is a no-op, because weight is created with dtype=self.dtype in __init__, so weight.dtype == self.dtype always holds and .to(self.dtype) can never change anything. If the intent (per the discussion thread) is to downcast fp32 params to the kernel's compute dtype before the CK LayerNorm, the target should be the activation dtype, e.g. weight.to(x.dtype), not self.dtype.
The key question that determines the shape of this PR: does CK layernorm2d_fwd / layernorm2d_fwd_with_add accept fp32 weight/bias?
If yes, no conversion is needed — just store fp32 params and pass them through; please drop the dead branch entirely.
If no, we genuinely need to downcast to the compute dtype before the kernel, and (1)+(3) must be fixed so the kernel actually receives the downcast tensors.
Note this only matters on the unfused path (use_qk_rope_cache_fusion=False), where k = self.k_norm(k) routes fp32 weights into the CK kernel. The fused indexer path already passes self.k_norm.weight/.bias directly (fp32), which is exactly what #3451 expects, and #3451 only changes the dtype contract on indexer_qk_rope_quant_and_cache — it doesn't touch layernorm2d_fwd. So could you confirm which case applies, ideally with the eager/unfused path exercised (not just the fused indexer)?
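For the "if no" case, the three fixes above could be combined roughly as in the sketch below. This is a minimal illustration, not the ATOM implementation: LayerNormSketch is a hypothetical class, and torch.nn.functional.layer_norm stands in for the CK layernorm2d_fwd call, whose actual signature and dtype requirements are not shown in this thread.

```python
import torch
import torch.nn.functional as F


class LayerNormSketch(torch.nn.Module):
    """Hypothetical stand-in for the LayerNorm under review (not the ATOM class)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        # Parameters are stored in fp32, which is what the fused indexer expects.
        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = torch.nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight, bias = self.weight, self.bias
        # Fix (2): compare a dtype with a dtype, not with a Tensor.
        # Fix (3): the cast target is the activation dtype, not self.dtype.
        if weight.dtype != x.dtype:
            weight = weight.to(x.dtype)
            bias = bias.to(x.dtype)
        # Fix (1): pass the local (possibly downcast) tensors to the kernel.
        # F.layer_norm is an assumption standing in for the CK kernel call.
        return F.layer_norm(x, (x.shape[-1],), weight, bias, self.eps)


module = LayerNormSketch(8)
x = torch.randn(2, 8, dtype=torch.bfloat16)
out = module(x)
```

The stored parameters stay fp32 here, so the fused path can still read self.weight/self.bias directly, and only the unfused call sees the downcast copies.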
Motivation
This is a companion PR to ROCm/aiter#3451.
ROCm/aiter#3451 requires LayerNorm weights to be in fp32 for use in the DeepSeek v3.2 fused indexer. This avoids casting the weights from fp32 to bf16, which loses precision. This PR aligns with that requirement by storing the LayerNorm weights in fp32 for the DeepSeek v3.2 ATOM pipeline.
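To make the precision loss concrete, here is a small standard-library sketch of an fp32 to bf16 round trip. bf16 keeps the upper 16 bits of an fp32 value (1 sign, 8 exponent, 7 mantissa bits); the helper names are hypothetical, round-to-nearest-even is assumed, and only finite values well inside the representable range are handled.

```python
import struct


def fp32_to_bf16_bits(value: float) -> int:
    """Round a finite fp32 value to bf16 (ties to even) and return its 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Rounding bias: 0x7FFF plus the lowest kept bit, which yields ties-to-even.
    rounded = bits + 0x7FFF + ((bits >> 16) & 1)
    return (rounded >> 16) & 0xFFFF


def bf16_bits_to_fp32(bits16: int) -> float:
    """Widen bf16 bits back to fp32 by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]


def bf16_roundtrip(value: float) -> float:
    """Return the value a weight would hold after an fp32 -> bf16 cast."""
    return bf16_bits_to_fp32(fp32_to_bf16_bits(value))


# Near 1.0 the bf16 spacing is 2**-7, so a weight of 1.001 collapses to 1.0.
print(bf16_roundtrip(1.001))
```

Any weight whose offset from a bf16 grid point is smaller than half the local spacing is rounded away like this, which is the loss the fp32 parameters avoid.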
Test Plan
Verify performance and accuracy before and after changes from this PR in combination with the companion aiter PR.
Test Result
Performance comparison
Configuration: ISL=1000, OSL=100, CONC=4
Accuracy
FP32 LayerNorm
Main
Submission Checklist