Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/prime_rl/trainer/ckpt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import bisect
import gc
import shutil
import time
import warnings
Expand Down Expand Up @@ -101,17 +102,29 @@ def load_state_dict(self, state_dict: dict[str, Any]):
)

# Re-initialize CPU offload wrappers after loading
has_cpu_offload = False
for opt in self.optimizers:
if isinstance(opt, CPUOffloadOptimizer):
opt._move_states("cpu")
opt._initialized = True
has_cpu_offload = True

if self.scheduler is not None:
self.scheduler.load_state_dict(state_dict["scheduler"])
if self.progress is not None:
for key, value in state_dict["progress"].items():
setattr(self.progress, key, value)

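# Reclaim the GPU memory that held optimizer states before they were moved to CPU.
# After set_state_dict + _move_states("cpu"), the optimizer states live on CPU,
# but the state_dict (owned by dcp_load) still holds references to the stale GPU
# optimizer tensors. Dropping those references and emptying the CUDA cache
# prevents OOM on the first training step.
# NOTE: on the CPU-offload path this clears the passed-in state_dict in place,
# so the caller's dict is empty once this method returns.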
if has_cpu_offload:
state_dict.clear() # drop stale GPU tensor references from dcp_load
gc.collect() # break any circular references so tensors are freed
torch.cuda.empty_cache() # return freed GPU memory to CUDA


class CheckpointManager:
"""Utility class to save and load trainer checkpoints to resume SFT and RL training."""
Expand Down
Loading