7 changes: 7 additions & 0 deletions dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -85,6 +85,11 @@ class CheckpointEvent:
global_shard_num: int = 0


@dataclass
class CheckpointNotifyEvent:
step: int = 0


@dataclass
class TensorMeta:
shape: Tuple[int] = None # type: ignore
@@ -451,6 +456,7 @@ def __init__(
self._latest_step = 0
qname = CheckpointSharedObjPrefix.SAVE_STEP_QNAME + str(0)
self._event_queue = SharedQueue(name=qname, create=True)
self._notify_queue = SharedQueue(name=qname + "_notify", create=True)
for i in range(self.local_shard_num):
self._shm_handlers.append(SharedMemoryHandler(i))
lock_name = CheckpointSharedObjPrefix.SHM_LOCK_NAME + str(i)
@@ -707,6 +713,7 @@ def _save_shard(
return False
finally:
shm_lock.release()
self._notify_queue.put(CheckpointNotifyEvent(step=step))

def _dist_make_dir(self, path, timeout=30):
if self._node_rank == 0:
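
Taken together, the new CheckpointNotifyEvent and the _notify_queue form an acknowledgement handshake: after _save_shard has flushed a shard and released the shared-memory lock, the saver puts a notify event, and the training process drains it before its next save_to_memory so it does not overwrite shared memory that is still being persisted. Below is a minimal standalone sketch of that pattern; it uses multiprocessing.Queue as a stand-in for dlrover's SharedQueue and only illustrates the flow, it is not the PR's code.

import multiprocessing as mp
from dataclasses import dataclass


@dataclass
class CheckpointNotifyEvent:
    step: int = 0


def saver_process(notify_queue, step):
    # ... the async saver would persist the shard from shared memory here ...
    notify_queue.put(CheckpointNotifyEvent(step=step))  # acknowledge completion


def train_process(notify_queue, pending_step):
    # Before reusing shared memory, wait for the saver to acknowledge the
    # step recorded by the previous save_to_storage call.
    if pending_step > 0:
        event = notify_queue.get()
        assert event.step == pending_step
    # ... now it is safe to copy the next state dict into shared memory ...


if __name__ == "__main__":
    queue = mp.Queue()
    worker = mp.Process(target=saver_process, args=(queue, 100))
    worker.start()
    train_process(queue, pending_step=100)
    worker.join()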
40 changes: 40 additions & 0 deletions dlrover/trainer/tests/torch/checkpoint_egine_test.py
@@ -362,6 +362,46 @@ def test_sync_group(self):
finally:
dist.destroy_process_group()

def test_fast_save_memory(self):
engines = [
FullCheckpointEngine,
DeepSpeedCheckpointEngine,
]
for engine in engines:
self._test_fast_save_memory(engine)

def _test_fast_save_memory(self, engine_class):
model = SimpleNet()
state_dict = dict(
model=model.state_dict(),
step=100,
)
storage = PosixDiskStorage()
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint_engine = engine_class(tmpdir, storage)
tmp = Path(tmpdir)
saved_file = tmp / "checkpoint-100/checkpoint.pt"
sd = {CheckpointConstant.MODEL_STATES_NAME: state_dict}
paths = {CheckpointConstant.MODEL_STATES_NAME: saved_file}
checkpoint_engine.save_to_storage(100, sd, paths)

# Simulate quick save_to_memory after save_to_storage.
# save_to_memory will wait for the async saving to complete,
# so no need to sleep here.
checkpoint_engine.save_to_memory(101, sd, paths)

# Check the tracker file and the checkpoint; the step should
# still be 100, which was stored by save_to_storage.
tracker_file = tmp / CheckpointConstant.TRACER_FILE_NAME
self.assertTrue(storage.exists(tracker_file))
self.assertEqual(tracker_file.read_text(), "100")
state = torch.load(saved_file)
self.assertEqual(state["step"], 100)

saver: AsyncCheckpointSaver = AsyncCheckpointSaver.get_ckpt_saver()
saver.close()
checkpoint_engine.close()


class PosixDiskStorageTest(unittest.TestCase):
def setUp(self):
29 changes: 29 additions & 0 deletions dlrover/trainer/tests/torch/fsdp_ckpt_test.py
@@ -419,3 +419,32 @@ def test_fsdp_checkpointer(self):
self.assertListEqual(files, [".metadata", "__0_0.distcp"])
reader = checkpointer._engine.load(path)
self.assertTrue(isinstance(reader, SharedMemoryReader))

def test_fast_save_memory(self):
state_dict = {"step": 100}
storage = PosixDiskStorage()
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
path = tmpdir / str(100)
paths = {CheckpointConstant.MODEL_STATES_NAME: path}
engine = FsdpCheckpointEngine(tmpdir, storage)
engine.save_to_storage(100, state_dict, paths=paths)
self.assertEqual(engine._cached_step, 100)

# Simulate quick save_to_memory after save_to_storage.
# save_to_memory will wait for the async saving to complete,
# so no need to sleep here.
engine.save_to_memory(101, state_dict, paths)
self.assertEqual(engine._cached_step, 101)

# Check if the files are created correctly.
self.assertTrue(storage.exists(tmpdir / "._dlrover_ckpt_stage"))
self.assertTrue(storage.exists(tmpdir / "100/__0_0.distcp"))
# Check the tracker file; the step should still be 100,
# which was stored by save_to_storage.
tracker_file = tmpdir / CheckpointConstant.TRACER_FILE_NAME
self.assertTrue(storage.exists(tracker_file))
self.assertEqual(tracker_file.read_text(), "100")
reader = engine.load(path)
self.assertTrue(isinstance(reader, SharedMemoryReader))
5 changes: 5 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/deepspeed_engine.py
@@ -89,6 +89,10 @@ def save_to_memory(self, step, state_dict, paths):
["model_states", "optim_states"] of the state dict and
the value is the path of storage to save.
"""
if self._checkpoint_event_step > 0:
notify_event = self._notify_queue.get()
assert notify_event.step == self._checkpoint_event_step
self._checkpoint_event_step = -1
conf = CheckpointConfig(step=step, paths=paths)
success = self.save_state_dict_to_memory(state_dict, conf)
return success
@@ -120,6 +124,7 @@ def save_to_storage(self, step, state_dict, paths):
if self._local_rank == 0 and success:
event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
self._event_queue.put(event)
self._checkpoint_event_step = step
if success:
self.latest_step = step
Review comment (Contributor): All ranks should expect a notify to drain their local shard queue; suggested change:

        if self._local_rank == 0 and success:
            event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
            self._event_queue.put(event)
        if success:
            self._checkpoint_event_step = step
            self.latest_step = step

return success
8 changes: 8 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -204,8 +204,16 @@ def __init__(
name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME + str(0),
create=False,
)
self._notify_queue = SharedQueue(
name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME
+ str(0)
+ "_notify",
create=False,
)

Review comment (Collaborator): implement this magic concatenation as a method of ckpt_saver.
else:
self._event_queue = None # type: ignore
self._notify_queue = None # type: ignore
self._checkpoint_event_step = -1
self._update_saver_config()

# lock for shared memory
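
One possible shape for the helper suggested in the review comment above, keeping the queue-name scheme in one place; the prefix value and method names below are placeholders, not dlrover's real constants or API.

class CheckpointSharedObjPrefix:
    SAVE_STEP_QNAME = "ckpt_save_step_q_"  # placeholder value

    @classmethod
    def save_step_qname(cls, local_rank: int = 0) -> str:
        return f"{cls.SAVE_STEP_QNAME}{local_rank}"

    @classmethod
    def notify_qname(cls, local_rank: int = 0) -> str:
        # Owns the "<SAVE_STEP_QNAME><rank>_notify" concatenation so the
        # trainer engine and the saver always agree on the queue name.
        return f"{cls.save_step_qname(local_rank)}_notify"

Both the engine and ckpt_saver could then call CheckpointSharedObjPrefix.notify_qname(0) instead of concatenating the strings inline.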
6 changes: 6 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py
@@ -487,6 +487,11 @@ def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
if self._local_rank != self.local_shard_id:
return False

if self._checkpoint_event_step > 0:
notify_event = self._notify_queue.get()
assert notify_event.step == self._checkpoint_event_step
self._checkpoint_event_step = -1

Review comment (Collaborator): move this part into the parent class (engine).

acquired = self._shm_lock.acquire(blocking=False)
all_rank_ready = check_all_rank_ready(self._saver_group, acquired)
if not all_rank_ready:
@@ -548,6 +553,7 @@ def save_to_storage(self, step, state_dict, paths: Dict[str, str]):
)
event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
self._event_queue.put(event)
self._checkpoint_event_step = step
if success:
self.latest_step = step
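
Responding to the "move this part into the parent class" comment above, here is a sketch of how the notify-drain block (also duplicated in the DeepSpeed and full-checkpoint engines) might live once in a shared parent; the class and method names are illustrative, and queue.Queue stands in for the SharedQueue used in the diff.

from queue import Queue  # stand-in for dlrover's SharedQueue


class CheckpointEngineBase:
    """Illustrative parent that owns the notify-drain step once."""

    def __init__(self, notify_queue: Queue):
        self._notify_queue = notify_queue
        self._checkpoint_event_step = -1

    def _wait_for_pending_save(self):
        # Block until the saver acknowledges the step recorded by the last
        # save_to_storage call, then clear the pending marker. This mirrors
        # the block currently inlined in each subclass's save_to_memory.
        if self._checkpoint_event_step > 0:
            notify_event = self._notify_queue.get()
            assert notify_event.step == self._checkpoint_event_step
            self._checkpoint_event_step = -1

Each subclass's save_to_memory would then start with self._wait_for_pending_save() instead of repeating these three lines.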

5 changes: 5 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py
@@ -112,6 +112,10 @@ def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
["model_states", "optim_states"] of the state dict and
the value is the path of storage to save.
"""
if self._checkpoint_event_step > 0:
notify_event = self._notify_queue.get()
assert notify_event.step == self._checkpoint_event_step
self._checkpoint_event_step = -1
conf = CheckpointConfig(step=step, paths=paths)
return self.save_state_dict_to_memory(state_dict, conf)

Review comment (Collaborator, @BalaBalaYi, Jul 16, 2025): What is your intention if the assert is false? In my opinion, this situation should be handled softly, e.g., by skipping the step and logging an explanation. The previous implementation that failed directly did not cause many issues; at most, the log made it hard to understand why the steps were "not equal".

@@ -140,6 +144,7 @@ def save_to_storage(self, step, state_dict, paths):
if success and self._rank == 0:
event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
self._event_queue.put(event)
self._checkpoint_event_step = step
if success:
self.latest_step = step
return success
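
If the soft handling suggested by the reviewer above is preferred over the assert, the drain could log a warning and continue; a hypothetical helper, not part of this PR:

import logging

logger = logging.getLogger(__name__)


def drain_notify_softly(notify_queue, expected_step):
    """Warn and continue when the acknowledged step does not match."""
    event = notify_queue.get()
    if event.step != expected_step:
        logger.warning(
            "Checkpoint notify step %s does not match the pending step %s; "
            "skipping the strict check and continuing.",
            event.step,
            expected_step,
        )
        return False
    return True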