Describe the bug
When using dist_checkpointing with multiple nodes, the checkpoint metadata file 'metadata.json' is only saved on the master node, rather than on each node.
However, when loading the exact same checkpoint, the load path requires the metadata to be accessible on all nodes (from verify_checkpoint_and_load_strategy in megatron/core/dist_checkpointing/validation.py):
if not Path(checkpoint_dir).exists():
    raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
    raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
It looks like the current Megatron-LM dist_checkpointing assumes either a shared file system between the nodes, or that the user synchronizes the metadata across nodes manually.
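For reference, such a manual synchronization could look roughly like the sketch below. This is illustrative only: sync_checkpoint_metadata is a hypothetical helper, not part of Megatron-LM; it assumes torch.distributed is already initialized with a CUDA device set per rank and that torchrun sets LOCAL_RANK. Depending on the save strategy, other files written only by the coordinator rank (e.g. a .metadata index) may also need to be replicated, or a shared filesystem used.

# Hypothetical helper, not part of Megatron-LM: copy metadata.json from the
# master node to every node's local checkpoint directory before load.
import os
from pathlib import Path
import torch

def sync_checkpoint_metadata(ckpt_dir: Path, filename: str = 'metadata.json'):
    payload = [None]
    if torch.distributed.get_rank() == 0:
        # Only the master node has this file after dist_checkpointing.save.
        payload[0] = (ckpt_dir / filename).read_text()
    # Broadcast the file content to all ranks (requires an initialized process group).
    torch.distributed.broadcast_object_list(payload, src=0)
    # One writer per node is enough; torchrun sets LOCAL_RANK per process.
    if int(os.environ['LOCAL_RANK']) == 0:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        (ckpt_dir / filename).write_text(payload[0])
    torch.distributed.barrier()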
To Reproduce
Use the dummy save/load example below and launch it on multiple nodes.
# run the command below on two nodes
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
--nnodes=2 \
--rdzv-backend=static \
--node_rank=$NODE_RANK \
--rdzv_id=nemo_$WORLD_SIZE \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
dist_cp_save_load.py
# dist_cp_save_load.py
from pathlib import Path
import os
import torch
import shutil
import logging
from megatron.core import dist_checkpointing
from local_dev.utils import env_setup
def run():
    # Setup the logging and relevant environment variables.
    env_setup.set_print_env()

    # TODO: make the path to be configurable.
    ckpt_root = Path('/tmp/checkpoints')
    print(f"Starting the Megatron-LM dist_checkpoint save/load benchmark: world_size:{os.environ['WORLD_SIZE']}, num_node:{os.environ['NNODES']}, ckpt_root:{ckpt_root}")
    if Path(ckpt_root).exists():
        shutil.rmtree(ckpt_root, ignore_errors=True)  # Deletes the entire directory and its contents
    native_ckpt_root = ckpt_root / 'native'
    native_ckpt_root.mkdir(exist_ok=True, parents=True)
    dist_ckpt_root = ckpt_root / 'dist_ckpt'
    dist_ckpt_root.mkdir(exist_ok=True, parents=True)

    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    local_rank_idx = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank_idx)

    # Local tensor to save: 1Gi int64 elements in total, split evenly across all ranks.
    assert (1024 * 1024 * 1024) % world_size == 0
    num_elems_per_rank = 1024 * 1024 * 1024 // world_size
    local_ten = torch.arange(start=num_elems_per_rank * rank,
                             end=num_elems_per_rank * (rank + 1))

    # Native torch.save checkpoint (one file per rank).
    state_dict = {
        'weight': local_ten
    }
    torch.save(state_dict, native_ckpt_root / f'ckpt_{rank}.pt')
    print("Saved native checkpoint")

    # Distributed checkpoint save.
    # `(0, rank, world_size)` describes that the `weight` ShardedTensor is sharded into `world_size` pieces
    # along the 0th dimension and `local_ten` is the shard at position `rank`.
    # Together, all shards implicitly form a "global" arange tensor over all elements.
    sharded_state_dict = {
        'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten, (0, rank, world_size))
    }
    dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)
    print("Saved dist checkpoint")

    print("Starting to load the dist checkpoint.")
    # Pre-allocated buffer for the local shard; dtype matches the saved int64 tensor.
    local_ten_to_load = torch.empty(num_elems_per_rank, dtype=torch.int64)
    sharded_state_dict = {
        'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten_to_load, (0, rank, world_size))
    }
    loaded_state_dict = dist_checkpointing.load(sharded_state_dict, dist_ckpt_root)
    expected_local_ten = torch.arange(start=num_elems_per_rank * rank, end=num_elems_per_rank * (rank + 1))
    assert torch.all(loaded_state_dict['weight'] == expected_local_ten)
    print("Loaded the dist checkpoint.")

    torch.distributed.destroy_process_group()

if __name__ == '__main__':
    run()
Expected behavior
The distributed checkpoint saved with dist_checkpointing.save should load successfully on every node, without requiring metadata.json to be manually copied to non-master nodes (or the metadata should be saved/synchronized on every node during save).
Stack trace/logs
On node 1 (the worker node), where metadata.json is not present, the load fails:
[rank9]: Traceback (most recent call last):
[rank9]: File "/workspace/Megatron-LM/local_dev/mcore_checkpoint_benchmark/dist_save_load.py", line 78, in <module>
[rank9]: run()
[rank9]: File "/workspace/Megatron-LM/local_dev/mcore_checkpoint_benchmark/dist_save_load.py", line 69, in run
[rank9]: loaded_state_dict = dist_checkpointing.load(sharded_state_dict, dist_ckpt_root)
[rank9]: File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 101, in load
[rank9]: sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
[rank9]: File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/validation.py", line 227, in verify_checkpoint_and_load_strategy
[rank9]: raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
[rank9]: megatron.core.dist_checkpointing.core.CheckpointingException: /tmp/checkpoints/dist_ckpt is not a distributed checkpoint
Proposed fix
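One possible direction (a sketch only, not a tested patch): perform the directory / metadata.json check on a single rank inside verify_checkpoint_and_load_strategy and broadcast the loaded config to the other ranks, so that worker nodes without a local metadata.json do not fail validation. The helper name below is hypothetical, the import path for maybe_load_config is assumed from the excerpt and traceback above, and the underlying shard files would still have to be readable on every node.

# Hypothetical sketch of a rank-0-only metadata check; not the actual Megatron-LM code.
import torch
from pathlib import Path
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config

def load_config_rank0_and_broadcast(checkpoint_dir):
    holder = [None]
    if torch.distributed.get_rank() == 0:
        if not Path(checkpoint_dir).exists():
            raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
        holder[0] = maybe_load_config(checkpoint_dir)
        if holder[0] is None:
            raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
    # Other ranks receive the config instead of reading metadata.json from local disk.
    # Note: if rank 0 raises above, the other ranks would block here; a real fix
    # would need to propagate the error as well.
    torch.distributed.broadcast_object_list(holder, src=0)
    return holder[0]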