[BUG] dist_checkpointing metadata is only saved in master node for multi-node training. #1530

Open · felixwqp opened this issue Apr 12, 2025 · 0 comments · May be fixed by #1531

felixwqp commented Apr 12, 2025

Describe the bug
When using dist_checkpointing with multiple nodes, the checkpoint metadata file metadata.json is only saved on the master node, rather than on every node:

def metadata_finalize_fn():
    if torch.distributed.get_rank() == 0:
        save_config(
            CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
            checkpoint_dir,
        )
    torch.distributed.barrier()

However, when loading the exact same checkpoint, the load path requires the metadata to be accessible on all nodes:

    if not Path(checkpoint_dir).exists():
        raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')

    saved_config = maybe_load_config(checkpoint_dir)
    if saved_config is None:
        raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')

It looks like the current Megatron-LM dist_checkpointing assumes a shared file system between the nodes; otherwise the user has to synchronize the metadata to every node themselves.
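
As a workaround until this is addressed upstream, the metadata can be replicated manually right after dist_checkpointing.save. The sketch below is not part of Megatron-LM and only illustrates the idea; it assumes torchrun has set LOCAL_RANK and that every node uses the same local checkpoint path:

import os
from pathlib import Path

import torch


def replicate_metadata(checkpoint_dir):
    # Broadcast the metadata.json contents from global rank 0 so one process
    # per node can write it into that node's local checkpoint directory.
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    payload = [None]
    if rank == 0:
        payload[0] = (Path(checkpoint_dir) / 'metadata.json').read_text()
    torch.distributed.broadcast_object_list(payload, src=0)

    # Global rank 0 already has the file on disk; one rank per other node writes it.
    if rank != 0 and local_rank == 0:
        (Path(checkpoint_dir) / 'metadata.json').write_text(payload[0])
    torch.distributed.barrier()

Calling this on every rank right after dist_checkpointing.save(...) lets the subsequent load find metadata.json on all nodes.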

To Reproduce

Use the dummy save/load example below and launch it on multiple nodes.

# run the command below on two nodes
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
  --nnodes=2 \
  --rdzv-backend=static \
  --node_rank=$NODE_RANK \
  --rdzv_id=nemo_$WORLD_SIZE \
  --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
dist_cp_save_load.py
# dist_cp_save_load.py
from pathlib import Path

import os
import torch
import shutil
import logging

from megatron.core import dist_checkpointing

from local_dev.utils import env_setup



def run():
    # Setup the logging and relevant environment variables.
    env_setup.set_print_env()


    # TODO: make the path to be configurable.
    ckpt_root = Path('/tmp/checkpoints')

    print(f"Starting the Megatron-LM dist_checkpoint save_save_benchmark: work_size:{os.environ['WORLD_SIZE']}, num_node:{os.environ['NNODES']}, ckpt_root:{ckpt_root}")
    if Path(ckpt_root).exists():
        shutil.rmtree(ckpt_root, ignore_errors=True)  # Deletes the entire directory and its contents

    native_ckpt_root = ckpt_root / 'native'
    native_ckpt_root.mkdir(exist_ok=True, parents=True)
    dist_ckpt_root = ckpt_root / 'dist_ckpt'
    dist_ckpt_root.mkdir(exist_ok=True, parents=True)

    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    local_rank_idx = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank_idx)

    # Local tensor to save: 1Gi int64 elements (~8 GiB total) split evenly across all ranks.
    num_total_elems = 1024 * 1024 * 1024
    assert num_total_elems % world_size == 0
    num_elems_per_rank = num_total_elems // world_size
    local_ten = torch.arange(start=num_elems_per_rank * rank,
                             end=num_elems_per_rank * (rank + 1))

    # Native python checkpoint save.
    state_dict = {
        'weight': local_ten
    }
    torch.save(state_dict, native_ckpt_root / f'ckpt_{rank}.pt')

    print("Saved native checkpoint")

    # Distributed checkpoint save
    # `(0, rank, world_size)` describes that `weight` ShardedTensor is sharded into `world_size` pieces
    # along the 0th dimension and `local_ten` is the shard at position `rank`.
    # Together, all shards implicitly form a "global" `torch.arange(num_total_elems)` tensor.
    sharded_state_dict = {
        'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten, (0, rank, world_size))
    }

    dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)
    print("Saved dist checkpoint")

    print("Starting to load the dist checkpoint.")
    local_ten_to_load = torch.empty(num_elems_per_rank, dtype=local_ten.dtype)
    sharded_state_dict = {
        'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten_to_load, (0, rank, world_size))
    }
    loaded_state_dict = dist_checkpointing.load(sharded_state_dict, dist_ckpt_root)
    expected_local_ten = torch.arange(start=num_elems_per_rank * rank, end=num_elems_per_rank * (rank + 1))
    assert torch.all(loaded_state_dict['weight'] == expected_local_ten)
    print("Loaded the disk checkpoint.")

    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    run()

Expected behavior
metadata.json should be saved on (or made accessible to) every node, so that the subsequent dist_checkpointing.load succeeds on all nodes even without a shared file system.

Stack trace/logs

[rank9]: Traceback (most recent call last):
[rank9]:   File "/workspace/Megatron-LM/local_dev/mcore_checkpoint_benchmark/dist_save_load.py", line 78, in <module>
[rank9]:     run()
[rank9]:   File "/workspace/Megatron-LM/local_dev/mcore_checkpoint_benchmark/dist_save_load.py", line 69, in run
[rank9]:     loaded_state_dict = dist_checkpointing.load(sharded_state_dict, dist_ckpt_root)
[rank9]:   File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 101, in load
[rank9]:     sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
[rank9]:   File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/validation.py", line 227, in verify_checkpoint_and_load_strategy
[rank9]:     raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
[rank9]: megatron.core.dist_checkpointing.core.CheckpointingException: /tmp/checkpoints/dist_ckpt is not a distributed checkpoint

Node 0 (the master node):

root@vm1:/workspace# ls -lt /tmp/checkpoints/dist_ckpt/
total 4194344
-rw-r--r-- 1 root root       119 Apr 12 07:46 metadata.json
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __6_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __5_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __2_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __7_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __0_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __1_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __4_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __3_0.distcp
-rw-r--r-- 1 root root       860 Apr 12 07:46 common.pt

Node 1 (the worker node):

root@vm2:/workspace# ls -lt /tmp/checkpoints/dist_ckpt/
total 4194336
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __11_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __14_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __9_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __13_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __15_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __10_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __8_0.distcp
-rw-r--r-- 1 root root 536872092 Apr 12 07:46 __12_0.distcp

Proposed fix

  • Instead of saving the metadata only on the rank with "global_rank == 0", save it on every rank with "local_rank == 0", e.g.:

def metadata_finalize_fn():
    if int(os.environ.get("LOCAL_RANK", "0")) == 0:
        save_config(
            CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
            checkpoint_dir,
        )
    torch.distributed.barrier()
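
As a sanity check (a hypothetical helper, not part of Megatron-LM; it assumes torch.distributed is already initialized and that each node saves the checkpoint to the same local path), every rank can verify that the metadata is visible on its own filesystem after saving:

from pathlib import Path

import torch


def assert_metadata_visible(checkpoint_dir):
    # Each rank checks its local filesystem for metadata.json; the results are
    # gathered so any node missing the file shows up in the error message.
    present = (Path(checkpoint_dir) / 'metadata.json').exists()
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, present)
    assert all(gathered), f'metadata.json missing on some ranks: {gathered}'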

