[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052
[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052allenphilipj wants to merge 3 commits into
Conversation
937ef34 to
80304fa
Compare
Greptile SummaryThis PR fixes
Confidence Score: 5/5The change is a minimal, targeted fix — two lines are added to retrieve the graph-capture tensor and one None literal is replaced. The new code is structurally identical to the already-tested pattern in Linear, LayerNormLinear, and LayerNormMLP. The new code block is a direct copy of the well-established pattern used by the other three TE linear modules. The only variable affected flows unchanged into an existing parameter slot; no control flow, dtypes, or tensor shapes are altered. The fix is additive and isolated to the FP8 graph-capture path. No files require special attention; the single changed file is straightforward and self-contained. Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant GroupedLinear.forward
participant FP8GlobalStateManager
participant _GroupedLinear.forward
participant _prepare_weights_for_grouped_tensor_gemm
participant quantize_weight
Caller->>GroupedLinear.forward: forward(inp, m_splits, is_first_microbatch)
GroupedLinear.forward->>FP8GlobalStateManager: fp8_graph_capturing()?
alt FP8 graph capture active
FP8GlobalStateManager-->>GroupedLinear.forward: skip_fp8_weight_update tensor
GroupedLinear.forward->>GroupedLinear.forward: "is_first_microbatch = False"
else Normal execution
GroupedLinear.forward->>GroupedLinear.forward: "skip_fp8_weight_update = None"
end
GroupedLinear.forward->>GroupedLinear.forward: "cache_weight = (is_first_microbatch is not None)"
GroupedLinear.forward->>_GroupedLinear.forward: non_tensor_args (includes skip_fp8_weight_update)
_GroupedLinear.forward->>_prepare_weights_for_grouped_tensor_gemm: skip_fp8_weight_update
_prepare_weights_for_grouped_tensor_gemm->>quantize_weight: "skip_update_flag=skip_fp8_weight_update"
quantize_weight-->>_GroupedLinear.forward: (cached or freshly-quantized) FP8 weight
Reviews (12): Last reviewed commit: "Merge branch 'main' into codex-grouped-l..." | Re-trigger Greptile |
|
/te-ci pytorch |
d7a4caa to
1890acf
Compare
|
@ksivaman I've rebased on the latest main & resolved the conflicts, would much appreciate a follow-up review. |
ksivaman
left a comment
There was a problem hiding this comment.
This test is not necessary here, the change itself looks good.
Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
5fd52b1 to
e445000
Compare
|
@ksivaman I dropped the test per feedback. The PR now contains only the GroupedLinear skip_fp8_weight_update propagation change. |
|
Want your agent to iterate on Greptile's feedback? Try greploops. |
|
/te-ci pytorch |
Summary:
Validation:
Fixes #3051