src/prime_rl/trainer/sft/train.py: 4 changes (1 addition, 3 deletions)
```diff
@@ -389,9 +389,7 @@ def run_validation(step: int) -> None:
     nan_loss_count = nan_loss_count.item()

     logger.debug(f"Clipping gradients with max norm {config.optim.max_norm}")
-    grad_norm = clip_grad_norm_(
-        model.parameters(), max_norm=config.optim.max_norm, ep_enabled=parallel_dims.ep_enabled
-    )
+    grad_norm = clip_grad_norm_(model.parameters(), max_norm=config.optim.max_norm)
```
**Incomplete fix: RL trainer still uses broken `ep_enabled`** (Medium Severity)

The `ep_enabled=parallel_dims.ep_enabled` parameter was removed from `clip_grad_norm_` in the SFT trainer because it "doesn't actually work," but the identical call in `src/prime_rl/trainer/rl/train.py` still passes `ep_enabled=parallel_dims.ep_enabled`. If the EP-enabled variant of `clip_grad_norm_` is broken, the RL trainer has the same problem and likely needs the same fix.

Collaborator


Imo this will be broken with standard EP, so it should just be `ep_enabled=use_deepep`, and the same should be added to the RL trainer? Though I'd swear it worked before with DeepEP.
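A minimal single-process sketch of the dispatch the collaborator suggests: only take the EP-aware gradient-clipping path when DeepEP is in use, not for standard expert parallelism. The `use_deepep` flag, the `ParallelDims` shape, and the stub `clip_grad_norm_` below are assumptions for illustration, not prime-rl's actual API; the real `ep_enabled` path would all-reduce the norm across expert-parallel ranks before clipping.

```python
class ParallelDims:
    """Hypothetical config carrying both EP flags; real field names may differ."""

    def __init__(self, ep_enabled: bool, use_deepep: bool):
        self.ep_enabled = ep_enabled
        self.use_deepep = use_deepep


def clip_grad_norm_(grads, max_norm, ep_enabled=False):
    """Stub standing in for the project's clip_grad_norm_.

    Computes an L2 norm over plain floats and scales them so the total norm
    does not exceed max_norm; returns which path was taken so the dispatch
    can be checked without a distributed setup.
    """
    norm = sum(g * g for g in grads) ** 0.5
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grads], ep_enabled


def clip_for(parallel_dims, grads, max_norm):
    # The suggested fix: gate the EP-aware path on use_deepep rather than
    # on ep_enabled, since standard EP reportedly breaks with it.
    return clip_grad_norm_(grads, max_norm, ep_enabled=parallel_dims.use_deepep)


# Standard EP (no DeepEP): grads [3, 4] have norm 5, so they are scaled by 1/5.
clipped, took_ep_path = clip_for(
    ParallelDims(ep_enabled=True, use_deepep=False), [3.0, 4.0], 1.0
)
```

With `use_deepep=False` the EP-aware branch is skipped even though `ep_enabled` is set, which is the behavior the comment argues both trainers should share.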

```diff
     if grad_norm.device.type == "cpu":
         grad_norm = grad_norm.to(torch.device("cuda"))
     zero_grad_ratio = get_zero_gradient_ratio(model.parameters(), parallel_dims.dp_replicate)
```