Skip to content

feat: add dtype option in LayerNorm class#1359

Open
cpersson-amd wants to merge 2 commits into
ROCm:mainfrom
cpersson-amd:ds32_fp32_layernorm
Open

feat: add dtype option in LayerNorm class#1359
cpersson-amd wants to merge 2 commits into
ROCm:mainfrom
cpersson-amd:ds32_fp32_layernorm

Conversation

@cpersson-amd

Copy link
Copy Markdown

Motivation

This is a companion PR to ROCm/aiter#3451.

ROCm/aiter#3451 requires that LayerNorm weights are in fp32 for use in the DeepSeek v3.2 fused indexer, this is done to prevent casting weights from fp32 to bf16 which results in a loss of precision. This PR aligns with this by making the LayerNorm weights FP32 for the DeepSeek v3.2 ATOM pipeline.

Test Plan

Verify performance and accuracy before and after changes from this PR in combination with the companion aiter PR.

Test Result

Performance comparison

Configuration: ISL=1000, OSL=100, CONC=4

Metric FP32 LayerNorm Main Δ (%)
Time to First Token (TTFT, ms) 229.31 216.36 -5.6%
Time per Output Token (TPOT, ms) 13.08 13.25 +1.3%
Inter-token Latency (ITL, ms) 13.08 13.25 +1.3%
End-to-end Latency (E2EL, ms) 1524.64 1528.34 +0.2%

Accuracy

FP32 LayerNorm

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.9568 ± 0.0056
strict-match 5 exact_match 0.9560 ± 0.0056

Main

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.9598 ± 0.0054
strict-match 5 exact_match 0.9591 ± 0.0055

Submission Checklist

@ChuanLi1101

Copy link
Copy Markdown
Collaborator

Thanks Carl! The overall direction looks right (storing k_norm params as fp32 so the fused indexer in ROCm/aiter#3451 gets fp32 weights), but the conversion block in LayerNorm.forward is currently a no-op, so a few things should be fixed before merge:

  1. The converted tensors are never used. Both layernorm2d_fwd_ and layernorm2d_fwd_with_add_ are still called with self.weight/self.bias rather than the local weight/bias produced just above, so the new if branch has no effect:

    weight, bias = self.weight, self.bias
    if self.dtype is not None and self.dtype != weight:
        weight, bias = weight.to(self.dtype), bias.to(self.dtype)
    if residual is None:
        return layernorm2d_fwd_(x, self.weight, self.bias, self.eps, self.dim)   # <- uses self.weight/self.bias
    else:
        return layernorm2d_fwd_with_add_(x, self.weight, residual, self.bias, self.eps, self.dim)  # <- same

    These should pass weight/bias.

  2. self.dtype != weight compares a torch.dtype against a Tensor, so it isn't a dtype check — it should be weight.dtype.

  3. Even after fixing (1)/(2), the conversion is a no-op, because weight is created with dtype=self.dtype in __init__, so weight.dtype == self.dtype always holds and .to(self.dtype) can never change anything. If the intent (per the discussion thread) is to downcast fp32 params to the kernel's compute dtype before the CK LayerNorm, the target should be the activation dtype, e.g. weight.to(x.dtype), not self.dtype.

The key question that determines the shape of this PR: does CK layernorm2d_fwd / layernorm2d_fwd_with_add accept fp32 weight/bias?

  • If yes, no conversion is needed — just store fp32 params and pass them through; please drop the dead branch entirely.
  • If no, we genuinely need to downcast to the compute dtype before the kernel, and (1)+(3) must be fixed so the kernel actually receives the downcast tensors.

Note this only matters on the unfused path (use_qk_rope_cache_fusion=False), where k = self.k_norm(k) routes fp32 weights into the CK kernel. The fused indexer path already passes self.k_norm.weight/.bias directly (fp32), which is exactly what #3451 expects, and #3451 only changes the dtype contract on indexer_qk_rope_quant_and_cache — it doesn't touch layernorm2d_fwd. So could you confirm which case applies, ideally with the eager/unfused path exercised (not just the fused indexer)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants