diff --git a/src/llama_cookbook/utils/config_utils.py b/src/llama_cookbook/utils/config_utils.py
index eb4510bb7..4f5871bf0 100644
--- a/src/llama_cookbook/utils/config_utils.py
+++ b/src/llama_cookbook/utils/config_utils.py
@@ -6,6 +6,7 @@
 import torch.distributed as dist
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from torch.distributed.fsdp import ShardingStrategy
 from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
@@ -119,4 +120,6 @@ def check_fsdp_config(fsdp_config):
     if not fsdp_config.checkpoint_type in VALID_TYPES:
         raise ValueError(f"Invalid checkpoint_type {fsdp_config.checkpoint_type}")
-    
\ No newline at end of file
+
+    if isinstance(fsdp_config.sharding_strategy, str):
+        fsdp_config.sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy]
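
Note on the added conversion: `ShardingStrategy` is a standard enum in `torch.distributed.fsdp`, so indexing it by name (`ShardingStrategy["FULL_SHARD"]`) turns a string coming from a YAML/CLI config into the enum member FSDP expects, and raises `KeyError` on a typo. A minimal sketch of the behavior; the `FSDPConfig` dataclass and `normalize_sharding_strategy` helper here are hypothetical stand-ins for the project's `fsdp_config`, not part of this change:

```python
from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy


@dataclass
class FSDPConfig:
    # Hypothetical stand-in; only the field touched by this change is modeled.
    sharding_strategy: object = "FULL_SHARD"


def normalize_sharding_strategy(cfg: FSDPConfig) -> FSDPConfig:
    # Mirrors the logic added to check_fsdp_config: accept either a string
    # (e.g. from a YAML/CLI override) or an already-constructed enum member.
    if isinstance(cfg.sharding_strategy, str):
        # Name lookup fails fast on unknown values such as "full_shard",
        # instead of silently passing a bad string through to FSDP.
        cfg.sharding_strategy = ShardingStrategy[cfg.sharding_strategy]
    return cfg


cfg = normalize_sharding_strategy(FSDPConfig("SHARD_GRAD_OP"))
assert cfg.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP
```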