diff --git a/examples/gr00t_n1_5/README.md b/examples/gr00t_n1_5/README.md
index cc5074cc21..641d96c980 100644
--- a/examples/gr00t_n1_5/README.md
+++ b/examples/gr00t_n1_5/README.md
@@ -28,7 +28,7 @@ cd FlagScale/
 pip install ".[cuda-train]" --verbose
 ```
 
-Install additional dependencies for downloading models/datasets:
+Install additional dependencies for downloading datasets:
 
 ```sh
 # For HuggingFace Hub
@@ -38,33 +38,6 @@ pip install huggingface_hub
 pip install modelscope
 ```
 
-## Download Models
-
-Download the pretrained GR00T N1.5 model using the provided script. Choose either HuggingFace Hub or ModelScope:
-
-**Using HuggingFace Hub:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id nvidia/GR00T-N1.5-3B \
-    --output_dir /workspace/models \
-    --source huggingface
-```
-
-**Using ModelScope:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id nvidia/GR00T-N1.5-3B \
-    --output_dir /workspace/models \
-    --source modelscope
-```
-
-The model will be downloaded to (example with `/workspace/models`):
-- `/workspace/models/nvidia/GR00T-N1.5-3B`
-
 ## Training
 
 ### Prepare Dataset
diff --git a/examples/gr00t_n1_5/conf/train.yaml b/examples/gr00t_n1_5/conf/train.yaml
index 89ef63badb..a0fda6bf4c 100644
--- a/examples/gr00t_n1_5/conf/train.yaml
+++ b/examples/gr00t_n1_5/conf/train.yaml
@@ -26,6 +26,7 @@ experiment:
     CUDA_DEVICE_MAX_CONNECTIONS: 1
     WANDB_MODE: offline
     OTEL_SDK_DISABLED: true
+    HF_ENDPOINT: "https://hf-mirror.com"
 
 action: run
 
diff --git a/examples/gr00t_n1_5/conf/train/gr00t_n1_5.yaml b/examples/gr00t_n1_5/conf/train/gr00t_n1_5.yaml
index 0cfd4cc9d8..27e9f5dad4 100644
--- a/examples/gr00t_n1_5/conf/train/gr00t_n1_5.yaml
+++ b/examples/gr00t_n1_5/conf/train/gr00t_n1_5.yaml
@@ -17,7 +17,7 @@ system:
 model:
   model_name: gr00t_n1_5
   # Path or HuggingFace model ID for the pretrained GR00T N1.5 model
-  checkpoint_dir: /workspace/models/nvidia/GR00T-N1.5-3B
+  checkpoint_dir: nvidia/GR00T-N1.5-3B
 
   # Fine-tuning control
   tune_llm: true
diff --git a/examples/pi0/README.md b/examples/pi0/README.md
index 0e82befbe3..604d9eab8b 100644
--- a/examples/pi0/README.md
+++ b/examples/pi0/README.md
@@ -30,7 +30,7 @@ pip install ".[cuda]" --verbose
 pip install git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi
 ```
 
-Install additional dependencies for downloading models/datasets:
+Install additional dependencies for downloading datasets:
 
 ```sh
 # For HuggingFace Hub
@@ -40,45 +40,6 @@ pip install huggingface_hub
 pip install modelscope
 ```
 
-## Download Models and Tokenizers
-
-Download models and tokenizers using the provided script. Choose either HuggingFace Hub or ModelScope based on your preference:
-
-**Using HuggingFace Hub:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id lerobot/pi0_base \
-    --output_dir /workspace/models \
-    --source huggingface
-
-python examples/pi0/download.py \
-    --repo_id google/paligemma-3b-pt-224 \
-    --output_dir /workspace/models \
-    --source huggingface
-```
-
-**Using ModelScope:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id lerobot/pi0_base \
-    --output_dir /workspace/models \
-    --source modelscope
-
-python examples/pi0/download.py \
-    --repo_id google/paligemma-3b-pt-224 \
-    --output_dir /workspace/models \
-    --source modelscope
-```
-
-The models will be downloaded to (example with `/workspace/models`):
-- `/workspace/models/lerobot/pi0_base`
-- `/workspace/models/google/paligemma-3b-pt-224`
-
-
 ## Training
 
 ### Prepare Dataset
diff --git a/examples/pi0/conf/train.yaml b/examples/pi0/conf/train.yaml
index 859bc50631..6c465da442 100644
--- a/examples/pi0/conf/train.yaml
+++ b/examples/pi0/conf/train.yaml
@@ -26,6 +26,7 @@ experiment:
     CUDA_DEVICE_MAX_CONNECTIONS: 1
     WANDB_MODE: offline
     OTEL_SDK_DISABLED: true
+    FLAGSCALE_USE_MODELSCOPE: true
 
 action: run
 
diff --git a/examples/pi0/conf/train/pi0.yaml b/examples/pi0/conf/train/pi0.yaml
index 29b8e2b51a..cced331a46 100644
--- a/examples/pi0/conf/train/pi0.yaml
+++ b/examples/pi0/conf/train/pi0.yaml
@@ -18,9 +18,9 @@ system:
 model:
   model_name: pi0
   # Path to the pretrained pi0_base model checkpoint
-  checkpoint_dir: /workspace/models/lerobot/pi0_base
+  checkpoint_dir: lerobot/pi0_base
   # Path to paligemma tokenizer
-  tokenizer_path: /workspace/models/google/paligemma-3b-pt-224
+  tokenizer_path: google/paligemma-3b-pt-224
   tokenizer_max_length: 48
 
 optimizer:
diff --git a/examples/pi0_5/README.md b/examples/pi0_5/README.md
index cbe06d3adb..cdd8473ae4 100644
--- a/examples/pi0_5/README.md
+++ b/examples/pi0_5/README.md
@@ -30,7 +30,7 @@ pip install ".[cuda]" --verbose
 pip install git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi
 ```
 
-Install additional dependencies for downloading models/datasets:
+Install additional dependencies for downloading datasets:
 
 ```sh
 # For HuggingFace Hub
@@ -40,44 +40,6 @@ pip install huggingface_hub
 pip install modelscope
 ```
 
-## Download Models and Tokenizers
-
-Download models and tokenizers using the provided script. Choose either HuggingFace Hub or ModelScope based on your preference:
-
-**Using HuggingFace Hub:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id lerobot/pi05_base \
-    --output_dir /workspace/models \
-    --source huggingface
-
-python examples/pi0/download.py \
-    --repo_id google/paligemma-3b-pt-224 \
-    --output_dir /workspace/models \
-    --source huggingface
-```
-
-**Using ModelScope:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id lerobot/pi05_base \
-    --output_dir /workspace/models \
-    --source modelscope
-
-python examples/pi0/download.py \
-    --repo_id google/paligemma-3b-pt-224 \
-    --output_dir /workspace/models \
-    --source modelscope
-```
-
-The models will be downloaded to (example with `/workspace/models`):
-- `/workspace/models/lerobot/pi05_base`
-- `/workspace/models/google/paligemma-3b-pt-224`
-
 ## Training
 
 ### Prepare Dataset
diff --git a/examples/pi0_5/conf/train.yaml b/examples/pi0_5/conf/train.yaml
index 60e88123b4..c2144c8132 100644
--- a/examples/pi0_5/conf/train.yaml
+++ b/examples/pi0_5/conf/train.yaml
@@ -26,6 +26,7 @@ experiment:
     CUDA_DEVICE_MAX_CONNECTIONS: 1
     WANDB_MODE: offline
     OTEL_SDK_DISABLED: true
+    FLAGSCALE_USE_MODELSCOPE: true
 
 action: run
 
diff --git a/examples/pi0_5/conf/train/pi0_5.yaml b/examples/pi0_5/conf/train/pi0_5.yaml
index a33721daa3..e38c32574b 100644
--- a/examples/pi0_5/conf/train/pi0_5.yaml
+++ b/examples/pi0_5/conf/train/pi0_5.yaml
@@ -18,9 +18,9 @@ system:
 model:
   model_name: pi0.5
   # Path to the pretrained pi05_base model checkpoint
-  checkpoint_dir: /workspace/models/lerobot/pi05_libero_base
+  checkpoint_dir: lerobot/pi05_libero_base
   # Path to paligemma tokenizer
-  tokenizer_path: /workspace/models/google/paligemma-3b-pt-224
+  tokenizer_path: google/paligemma-3b-pt-224
   tokenizer_max_length: 200
   gradient_checkpointing: true
   freeze_vision_encoder: false
diff --git a/examples/qwen_gr00t/README.md b/examples/qwen_gr00t/README.md
index 46e36a51a2..e5a89b5323 100644
--- a/examples/qwen_gr00t/README.md
+++ b/examples/qwen_gr00t/README.md
@@ -28,7 +28,7 @@ cd FlagScale/
 pip install ".[cuda-train]" --verbose
 ```
 
-Install additional dependencies for downloading models/datasets:
+Install additional dependencies for downloading datasets:
 
 ```sh
 # For HuggingFace Hub
@@ -38,34 +38,6 @@ pip install huggingface_hub
 pip install modelscope
 ```
 
-## Download Models
-
-Download the base VLM model. Qwen-GR00T supports Qwen3-VL and Qwen2.5-VL as the VLM backbone:
-
-**Using HuggingFace Hub:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id Qwen/Qwen3-VL-4B-Instruct \
-    --output_dir /workspace/models \
-    --source huggingface
-```
-
-**Using ModelScope:**
-
-```sh
-cd FlagScale/
-python examples/pi0/download.py \
-    --repo_id Qwen/Qwen3-VL-4B-Instruct \
-    --output_dir /workspace/models \
-    --source modelscope
-```
-
-The model will be downloaded to (example with `/workspace/models`):
-- `/workspace/models/Qwen/Qwen3-VL-4B-Instruct`
-
-
 ## Training
 
 ### Prepare Dataset
diff --git a/examples/qwen_gr00t/conf/train.yaml b/examples/qwen_gr00t/conf/train.yaml
index 1b98e3c9b6..a075c3780e 100644
--- a/examples/qwen_gr00t/conf/train.yaml
+++ b/examples/qwen_gr00t/conf/train.yaml
@@ -26,6 +26,7 @@ experiment:
     CUDA_DEVICE_MAX_CONNECTIONS: 1
     WANDB_MODE: offline
     OTEL_SDK_DISABLED: true
+    FLAGSCALE_USE_MODELSCOPE: true
 
 action: run
 
diff --git a/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml
index 85ba6bd351..fb87ae531b 100644
--- a/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml
+++ b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml
@@ -23,7 +23,7 @@ model:
   model_name: qwen_gr00t
   vlm:
     type: qwen3-vl
-    base_vlm: /workspace/models/Qwen/Qwen3-VL-4B-Instruct/
+    base_vlm: Qwen/Qwen3-VL-4B-Instruct
     attn_implementation: flash_attention_2
   action_model:
     # Whether to condition the action model on proprioceptive state (observation.state)
diff --git a/flagscale/models/pi0/configuration_pi0.py b/flagscale/models/pi0/configuration_pi0.py
index 8fd902b4b3..c646bfe0be 100644
--- a/flagscale/models/pi0/configuration_pi0.py
+++ b/flagscale/models/pi0/configuration_pi0.py
@@ -24,6 +24,7 @@
 from flagscale.models.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from flagscale.models.utils.constants import OBS_IMAGES
+from flagscale.models.utils.hub_utils import resolve_model_path
 from flagscale.models.vla.pretrained_config import PreTrainedConfig
 
 
 DEFAULT_IMAGE_SIZE = 224
@@ -185,6 +186,7 @@ def _from_dict(cls, data: dict[str, Any]) -> "PI0Config":
 
     @classmethod
     def from_pretrained(cls, config_dir: str, **kwargs: Any) -> "PI0Config":
+        config_dir = resolve_model_path(config_dir)
         config_path = os.path.join(config_dir, "config.json")
         if not os.path.exists(config_path):
             raise ValueError(f"config.json not found in {config_dir}")
diff --git a/flagscale/models/pi05/configuration_pi05.py b/flagscale/models/pi05/configuration_pi05.py
index 79b1dd3c0d..607c89973c 100644
--- a/flagscale/models/pi05/configuration_pi05.py
+++ b/flagscale/models/pi05/configuration_pi05.py
@@ -25,6 +25,7 @@
 import draccus
 
 from flagscale.models.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from flagscale.models.utils.hub_utils import resolve_model_path
 from flagscale.models.vla.pretrained_config import PreTrainedConfig
 
 DEFAULT_IMAGE_SIZE = 224
@@ -170,6 +171,7 @@ def _from_dict(cls, data: dict[str, Any]) -> "PI05Config":
 
     @classmethod
     def from_pretrained(cls, config_dir: str, **kwargs: Any) -> "PI05Config":
+        config_dir = resolve_model_path(config_dir)
         config_path = os.path.join(config_dir, "config.json")
         if not os.path.exists(config_path):
             raise ValueError(f"config.json not found in {config_dir}")
diff --git a/flagscale/models/utils/hub_utils.py b/flagscale/models/utils/hub_utils.py
new file mode 100644
index 0000000000..294ee784b9
--- /dev/null
+++ b/flagscale/models/utils/hub_utils.py
@@ -0,0 +1,88 @@
+import hashlib
+import os
+import tempfile
+from pathlib import Path
+
+import filelock
+
+from flagscale.logger import logger
+
+_lock_dir = tempfile.gettempdir()
+
+
+def use_modelscope() -> bool:
+    return os.environ.get("FLAGSCALE_USE_MODELSCOPE", "false").lower() == "true"
+
+
+# Copied from https://github.com/vllm-project/vllm/blob/1fc69f59bb0838c2ff6efc416dd8875c3e210d04/vllm/model_executor/model_loader/weight_utils.py
+def _get_lock(model_name_or_path: str, cache_dir: str | None = None) -> filelock.FileLock:
+    lock_dir = cache_dir or _lock_dir
+    os.makedirs(lock_dir, exist_ok=True)
+    model_name = str(model_name_or_path).replace("/", "-")
+    hash_name = hashlib.sha256(model_name.encode()).hexdigest()[:16]
+    # add hash to avoid conflict with old users' lock files
+    lock_file = os.path.join(lock_dir, f"{hash_name}-{model_name}.lock")
+    # mode 0o666 is required for the filelock to be shared across users
+    return filelock.FileLock(lock_file, mode=0o666)
+
+
+def resolve_model_path(
+    model_name_or_path: str,
+    revision: str | None = None,
+    cache_dir: str | None = None,
+    allow_patterns: list[str] | str | None = None,
+    ignore_patterns: list[str] | str | None = None,
+) -> str:
+    """Resolve a model name or HF/ModelScope repo ID to a local directory path.
+
+    If ``model_name_or_path`` is already a local directory, returns it as-is.
+    Otherwise downloads the model repo and returns the local cache path.
+
+    The download backend is selected by the ``FLAGSCALE_USE_MODELSCOPE`` env var:
+    - ``false`` (default): uses ``huggingface_hub.snapshot_download``
+    - ``true``: uses ``modelscope.hub.snapshot_download``
+
+    When ModelScope is enabled, ``HF_HUB_OFFLINE=1`` is set automatically so that
+    downstream HuggingFace calls (AutoTokenizer, cached_file, etc.) do not attempt
+    to reach huggingface.co.
+    """
+    if use_modelscope():
+        os.environ.setdefault("HF_HUB_OFFLINE", "1")
+
+    if Path(model_name_or_path).is_dir():
+        logger.info(f"Model path is local directory: {model_name_or_path}")
+        return model_name_or_path
+
+    with _get_lock(model_name_or_path, cache_dir):
+        if use_modelscope():
+            logger.info(f"Downloading model from ModelScope: {model_name_or_path}")
+            from modelscope.hub.snapshot_download import snapshot_download
+
+            local_path = snapshot_download(
+                model_id=model_name_or_path,
+                cache_dir=cache_dir,
+                revision=revision,
+                ignore_file_pattern=ignore_patterns,
+                allow_patterns=allow_patterns,
+            )
+        else:
+            logger.info(f"Downloading model from HuggingFace Hub: {model_name_or_path}")
+            from huggingface_hub import snapshot_download
+
+            local_path = snapshot_download(
+                model_name_or_path,
+                repo_type="model",
+                revision=revision,
+                cache_dir=cache_dir,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+            )
+
+    if not local_path:
+        raise RuntimeError(
+            f"Failed to download model '{model_name_or_path}': "
+            "snapshot_download returned an empty path"
+        )
+
+    logger.info(f"Model resolved to: {local_path}")
+    return local_path
diff --git a/flagscale/models/vla/base_policy.py b/flagscale/models/vla/base_policy.py
index 2d2eb02656..64187e787e 100644
--- a/flagscale/models/vla/base_policy.py
+++ b/flagscale/models/vla/base_policy.py
@@ -86,6 +86,7 @@ def from_config(cls, config: PreTrainedConfig) -> TrainablePolicy:
                 f"No policy registered for config type '{type_name}'. "
                 f"Known policies: {list(cls._registry.keys())}"
             )
+        config.resolve_pretrained_paths()
         return policy_cls(config=config)
 
     def save_pretrained(self, save_directory, *, state_dict=None) -> None:
diff --git a/flagscale/models/vla/gr00t_n1_5/configuration_gr00t_n1_5.py b/flagscale/models/vla/gr00t_n1_5/configuration_gr00t_n1_5.py
index cf57e25df4..d960461cfb 100644
--- a/flagscale/models/vla/gr00t_n1_5/configuration_gr00t_n1_5.py
+++ b/flagscale/models/vla/gr00t_n1_5/configuration_gr00t_n1_5.py
@@ -5,6 +5,7 @@
 
 from flagscale.models.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from flagscale.models.utils.constants import ACTION, OBS_STATE
+from flagscale.models.utils.hub_utils import resolve_model_path
 from flagscale.models.vla.pretrained_config import PreTrainedConfig
 
 if TYPE_CHECKING:
@@ -87,6 +88,9 @@ def validate_features(self) -> None:
                 f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}."
             )
 
+    def resolve_pretrained_paths(self) -> None:
+        self.base_model_path = resolve_model_path(self.base_model_path)
+
     @classmethod
     def from_train_config(cls, train_config: TrainConfig) -> Gr00tN15Config:
         model_cfg = train_config.model
diff --git a/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py b/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py
index 1800fe4d43..b9d21cd56d 100644
--- a/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py
+++ b/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py
@@ -20,8 +20,7 @@
 
 import torch
 import torch.nn.functional as F
-from huggingface_hub import hf_hub_download, snapshot_download
-from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
+from huggingface_hub import hf_hub_download
 from torch import nn
 from torch.distributions import Beta
 from transformers import PretrainedConfig, PreTrainedModel
@@ -696,17 +695,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
         print(f"Tune action head projector: {tune_projector}")
         print(f"Tune action head DiT: {tune_diffusion_model}")
 
-        # get the current model path being downloaded
-        try:
-            # NOTE(YL) This downloads the model to the local cache and returns the local path to the model
-            # saved in ~/.cache/huggingface/hub/
-            local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
-        # HFValidationError, RepositoryNotFoundError
-        except (HFValidationError, RepositoryNotFoundError):
-            print(
-                f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
-            )
-            local_model_path = pretrained_model_name_or_path
+        local_model_path = pretrained_model_name_or_path
 
         pretrained_model = super().from_pretrained(
             local_model_path, local_model_path=local_model_path, **kwargs
diff --git a/flagscale/models/vla/pretrained_config.py b/flagscale/models/vla/pretrained_config.py
index 6a495b98e6..12198e8b38 100644
--- a/flagscale/models/vla/pretrained_config.py
+++ b/flagscale/models/vla/pretrained_config.py
@@ -159,6 +159,14 @@ def observation_delta_indices(self) -> list | None:
     def action_delta_indices(self) -> list | None:
         raise NotImplementedError
 
+    def resolve_pretrained_paths(self) -> None:
+        """Resolve any HF/ModelScope repo IDs in this config to local paths.
+
+        Subclasses override to call ``resolve_model_path()`` on their
+        model-path fields. Called by ``TrainablePolicy.from_config()``
+        before model construction.
+        """
+
     @classmethod
     def from_train_config(cls, train_config: TrainConfig):
         """Build a config from the OmegaConf-based TrainConfig.
diff --git a/flagscale/models/vla/qwen_gr00t/configuration_qwen_gr00t.py b/flagscale/models/vla/qwen_gr00t/configuration_qwen_gr00t.py
index c90b6c68bd..12c4083177 100644
--- a/flagscale/models/vla/qwen_gr00t/configuration_qwen_gr00t.py
+++ b/flagscale/models/vla/qwen_gr00t/configuration_qwen_gr00t.py
@@ -6,6 +6,7 @@
 
 from flagscale.models.configs.types import NormalizationMode
 from flagscale.models.utils.constants import ACTION
+from flagscale.models.utils.hub_utils import resolve_model_path
 from flagscale.models.vla.action_model.gr00t_action_header import GR00TActionHeadConfig
 from flagscale.models.vla.pretrained_config import PreTrainedConfig
 from flagscale.models.vla.vlm.qwenvl_backbone import QwenVLConfig
@@ -46,6 +47,10 @@ def validate_features(self) -> None:
         if action_ft is None:
             raise ValueError(f"output_features must contain '{ACTION}' with type ACTION")
 
+    def resolve_pretrained_paths(self) -> None:
+        if self.vlm.base_vlm:
+            self.vlm.base_vlm = resolve_model_path(self.vlm.base_vlm)
+
     @classmethod
     def from_train_config(cls, train_config: TrainConfig) -> QwenGr00tConfig:
         model_cfg = train_config.model
diff --git a/flagscale/train/train_pi.py b/flagscale/train/train_pi.py
index bc73786ce7..402c83f58c 100644
--- a/flagscale/train/train_pi.py
+++ b/flagscale/train/train_pi.py
@@ -47,6 +47,7 @@
 from flagscale.models.configs.types import PolicyFeature
 from flagscale.models.utils.constants import ACTION, OBS_PREFIX, REWARD
 from flagscale.models.configs.types import FeatureType
+from flagscale.models.utils.hub_utils import resolve_model_path
 from flagscale.models.pi0.configuration_pi0 import PI0Config
 from flagscale.models.pi0.modeling_pi0 import PI0Policy
 from flagscale.models.pi05.configuration_pi05 import PI05Config
@@ -74,7 +75,7 @@ def set_seed(seed: int):
 
     if get_platform().name() == "cuda":
         torch.backends.cudnn.enabled = True
         torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = False 
+        torch.backends.cudnn.deterministic = False
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -486,11 +487,21 @@
             f"Invalid model_name: {model_name}. Must be 'pi0' or 'pi0.5'"
         )
 
+    # Resolve repo IDs to local paths up front.
+    # Note: config.model uses __getattr__ that reads from `raw` (OmegaConf) first,
+    # so attribute assignment doesn't stick. Use local variables instead.
+    tokenizer_path = resolve_model_path(
+        config.model.tokenizer_path,
+        ignore_patterns=["*.safetensors", "*.bin", "*.pt", "*.gguf"],
+    )
+
+    checkpoint_dir = resolve_model_path(config.model.checkpoint_dir)
+
     # Load base config from checkpoint
     if model_name == "pi0.5":
-        policy_config = PI05Config.from_pretrained(config.model.checkpoint_dir)
+        policy_config = PI05Config.from_pretrained(checkpoint_dir)
     else:
-        policy_config = PI0Config.from_pretrained(config.model.checkpoint_dir)
+        policy_config = PI0Config.from_pretrained(checkpoint_dir)
 
     # Override with any model-specific fields from YAML
     model_config_overrides = config.model.get_model_config_dict()
@@ -501,7 +512,7 @@ def main(config: TrainConfig, seed: int):
             logger.warning(f"Model config field '{key}' not found in {model_name} config, ignoring")
 
     # Set training-specific fields
-    policy_config.pretrained_path = config.model.checkpoint_dir
+    policy_config.pretrained_path = checkpoint_dir
     policy_config.use_amp = config.system.use_amp
 
     local_rank = init_distributed()
@@ -554,7 +565,7 @@ def main(config: TrainConfig, seed: int):
             },
             "norm_map": policy.config.normalization_mapping,
         },
-        "tokenizer_processor": {"tokenizer_name": config.model.tokenizer_path},
+        "tokenizer_processor": {"tokenizer_name": tokenizer_path},
     }
 
     if rename_map is not None:
diff --git a/tests/unit_tests/models/utils/test_hub_utils.py b/tests/unit_tests/models/utils/test_hub_utils.py
new file mode 100644
index 0000000000..60ed8a41e6
--- /dev/null
+++ b/tests/unit_tests/models/utils/test_hub_utils.py
@@ -0,0 +1,104 @@
+import os
+import tempfile
+import unittest
+from unittest.mock import MagicMock, patch
+
+from flagscale.models.utils.hub_utils import _get_lock, resolve_model_path, use_modelscope
+
+
+class TestUseModelscope(unittest.TestCase):
+    def test_default_is_false(self):
+        with patch.dict(os.environ, {}, clear=True):
+            self.assertFalse(use_modelscope())
+
+    def test_true(self):
+        with patch.dict(os.environ, {"FLAGSCALE_USE_MODELSCOPE": "true"}):
+            self.assertTrue(use_modelscope())
+
+    def test_true_case_insensitive(self):
+        with patch.dict(os.environ, {"FLAGSCALE_USE_MODELSCOPE": "True"}):
+            self.assertTrue(use_modelscope())
+
+    def test_false_explicit(self):
+        with patch.dict(os.environ, {"FLAGSCALE_USE_MODELSCOPE": "false"}):
+            self.assertFalse(use_modelscope())
+
+    def test_other_value_is_false(self):
+        with patch.dict(os.environ, {"FLAGSCALE_USE_MODELSCOPE": "1"}):
+            self.assertFalse(use_modelscope())
+
+
+class TestGetLock(unittest.TestCase):
+    def test_lock_file_created(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            lock = _get_lock("org/model", cache_dir=tmpdir)
+            self.assertTrue(lock.lock_file.startswith(tmpdir))
+            self.assertTrue(lock.lock_file.endswith(".lock"))
+
+    def test_different_models_different_locks(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            lock1 = _get_lock("org/model-a", cache_dir=tmpdir)
+            lock2 = _get_lock("org/model-b", cache_dir=tmpdir)
+            self.assertNotEqual(lock1.lock_file, lock2.lock_file)
+
+
+class TestResolveModelPath(unittest.TestCase):
+    def test_local_directory_returned_as_is(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            result = resolve_model_path(tmpdir)
+            self.assertEqual(result, tmpdir)
+
+    @patch("flagscale.models.utils.hub_utils.use_modelscope", return_value=False)
+    def test_hf_download(self, _mock_use_ms):
+        with patch("huggingface_hub.snapshot_download", return_value="/cache/model") as mock_dl:
+            result = resolve_model_path("org/model", revision="main", cache_dir="/tmp/cache")
+        mock_dl.assert_called_once_with(
+            "org/model",
+            repo_type="model",
+            revision="main",
+            cache_dir="/tmp/cache",
+            allow_patterns=None,
+            ignore_patterns=None,
+        )
+        self.assertEqual(result, "/cache/model")
+
+    @patch("flagscale.models.utils.hub_utils.use_modelscope", return_value=True)
+    def test_modelscope_download(self, _mock_use_ms):
+        mock_ms_module = MagicMock()
+        mock_ms_module.snapshot_download.return_value = "/cache/ms_model"
+        with patch.dict("sys.modules", {"modelscope.hub.snapshot_download": mock_ms_module}):
+            result = resolve_model_path("org/model", revision="v1", cache_dir="/tmp/ms")
+        mock_ms_module.snapshot_download.assert_called_once_with(
+            model_id="org/model",
+            cache_dir="/tmp/ms",
+            revision="v1",
+            ignore_file_pattern=None,
+            allow_patterns=None,
+        )
+        self.assertEqual(result, "/cache/ms_model")
+
+    @patch("flagscale.models.utils.hub_utils.use_modelscope", return_value=False)
+    def test_raises_on_empty_download_result(self, _mock_use_ms):
+        with (
+            patch("huggingface_hub.snapshot_download", return_value=None),
+            self.assertRaises(RuntimeError),
+        ):
+            resolve_model_path("org/model")
+
+    def test_allow_and_ignore_patterns(self):
+        with (
+            patch("flagscale.models.utils.hub_utils.use_modelscope", return_value=False),
+            patch("huggingface_hub.snapshot_download", return_value="/cache/m") as mock_dl,
+        ):
+            resolve_model_path(
+                "org/model",
+                allow_patterns=["*.safetensors"],
+                ignore_patterns=["*.bin"],
+            )
+        _, kwargs = mock_dl.call_args
+        self.assertEqual(kwargs["allow_patterns"], ["*.safetensors"])
+        self.assertEqual(kwargs["ignore_patterns"], ["*.bin"])
+
+
+if __name__ == "__main__":
+    unittest.main()