
Feat: deep-gemm-fp8 #2166

Open

S1ro1 wants to merge 4 commits into main from deep-gemm-fp8

Conversation

@S1ro1 (Collaborator) commented Apr 1, 2026

Note

High Risk
High risk because it adds a new FP8 execution/training path with custom Triton kernels and DeepGEMM autograd ops, which can impact numerical correctness, performance, and GPU/runtime compatibility (SM90-only, CUDA library loading changes).

Overview
Adds an opt-in model.fp8 flag that enables FP8 training on Hopper GPUs by replacing nn.Linear layers with DeepGEMM-backed FP8 blockwise matmuls and by switching MoE expert computation from torch._grouped_mm to a new FP8 grouped GEMM implementation.
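Assuming prime-rl's usual TOML trainer configs, enabling the new path would look something like the following; only the `fp8` field comes from this PR, the surrounding keys are illustrative:

```toml
# Illustrative trainer config fragment; only `fp8` is added by this PR.
[model]
impl = "custom"   # the FP8 path requires model.impl='custom' per the field description
fp8 = true        # opt-in FP8 training via DeepGEMM (SM90/Hopper GPUs only)
```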

Introduces new Triton FP8 quantization/layout utilities (fp8_utils.py) and custom autograd functions (fp8_linear.py, fp8_grouped_gemm.py), plumbs the fp8 flag through model configs and multiple MoE model configs, and adds a TileLang stability workaround by preloading libcudart.so with RTLD_GLOBAL.
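The quantization utilities themselves are not shown in this excerpt. As a rough, pure-Python illustration of the FP8 blockwise scaling idea (a sketch, not DeepGEMM's or `fp8_utils.py`'s actual kernel logic), each 128-element block gets its own scale so that the block's absolute maximum maps onto the finite FP8 E4M3 range (±448):

```python
# Sketch of per-block FP8 scale computation (hypothetical helper,
# not the PR's fp8_utils.py implementation).
FP8_E4M3_MAX = 448.0  # largest finite value in the float8_e4m3 format


def blockwise_scales(row, block_size=128):
    """Return one scale per block: amax / FP8_E4M3_MAX.

    Dividing a block by its scale maps it into the representable FP8
    range; the scale is stored alongside for dequantization.
    """
    scales = []
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        amax = max(abs(v) for v in block)
        scales.append(amax / FP8_E4M3_MAX if amax > 0 else 1.0)
    return scales
```

In the real implementation this runs as a Triton kernel over tensors, with 1x128 scaling for activations and 128x128 tiles for weights being the common DeepGEMM layout.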

Updates packaging to use a local deep-gemm wheel (tools/wheels/...) instead of downloading it from GitHub, with corresponding uv.lock changes.

Written by Cursor Bugbot for commit f4a45b7.


@cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.


        description="Whether to use FP8 training via DeepGEMM. Replaces nn.Linear layers with FP8 blockwise linear "
        "and uses FP8 grouped GEMM for MoE experts. Requires SM90 (Hopper) GPUs and model.impl='custom'.",
        ),
    ] = False


CHANGELOG not updated for new config field

Low Severity

A new fp8 config field was added to ModelConfig in src/prime_rl/configs/trainer.py, but CHANGELOG.md was not updated. Per the project rule, any PR that modifies configuration structures (added fields) in src/prime_rl/configs/*.py must include a corresponding CHANGELOG.md entry.


Triggered by project rule: BugBot Instructions

        self.block_size = block_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _FP8BlockwiseMM.apply(x, self.weight, self.block_size, torch.bfloat16)


FP8 linear forward silently drops bias parameter

Medium Severity

Float8BlockwiseLinear.forward computes only the FP8 matmul without adding self.bias, but from_linear faithfully copies bias from the source nn.Linear. Any converted layer with a non-None bias will silently produce incorrect results. The from_linear method needs to either reject biased layers or forward needs to add the bias.
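A minimal sketch of the two remedies, using plain-Python stand-ins rather than the PR's actual torch classes (`fp8_matmul` is a placeholder for the DeepGEMM kernel):

```python
# Hypothetical stand-ins illustrating the bias fix, not the PR's real code.
def fp8_matmul(x, weight):
    # Placeholder for the FP8 blockwise kernel: plain dot products here.
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


class Float8BlockwiseLinear:
    def __init__(self, weight, bias=None):
        self.weight = weight
        self.bias = bias

    @classmethod
    def from_linear(cls, linear):
        # Remedy 1: refuse to convert layers the FP8 path cannot represent.
        if linear.bias is not None:
            raise ValueError("FP8 blockwise linear does not support bias")
        return cls(linear.weight)

    def forward(self, x):
        y = fp8_matmul(x, self.weight)
        # Remedy 2: add the bias after the matmul instead of silently dropping it.
        if self.bias is not None:
            y = [yi + bi for yi, bi in zip(y, self.bias)]
        return y
```

Either remedy alone closes the silent-correctness gap; rejecting biased layers in `from_linear` is the smaller change if the target models never use biases.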

Additional Locations (1)
