diff --git a/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db.py b/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db.py index a20dafba..b58d4f26 100644 --- a/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db.py +++ b/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db.py @@ -197,17 +197,32 @@ class CheckpointFileHf(_CheckpointHf): @override def _download(self) -> str: """Download checkpoint and return the local path.""" - cmd_args = [ - self.repository, - "--repo-type", - "model", - "--revision", - self.revision, - self.filename, - ] - path = _hf_download(cmd_args) - assert os.path.exists(path), path - return path + repositories = [self.repository] + if self.repository in ["nvidia/Cosmos-Experimental", "nvidia-cosmos-ea/Cosmos-Experimental"]: + repositories = [ + self.repository, + "nvidia-cosmos-ea/Cosmos-Experimental" + if self.repository == "nvidia/Cosmos-Experimental" + else "nvidia/Cosmos-Experimental", + ] + + for repository in repositories: + try: + cmd_args = [ + repository, + "--repo-type", + "model", + "--revision", + self.revision, + self.filename, + ] + path = _hf_download(cmd_args) + assert os.path.exists(path), path + return path + except subprocess.CalledProcessError: + continue + + raise RuntimeError(f"Failed to download {self.filename} from any experimental repository.") class CheckpointDirHf(_CheckpointHf): diff --git a/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db_test.py b/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db_test.py index e0d63685..3ced6b39 100644 --- a/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db_test.py +++ b/cosmos_transfer2/_src/imaginaire/utils/checkpoint_db_test.py @@ -136,3 +136,104 @@ def test_get_checkpoint_dir(): path = Path(path) assert path.is_dir() assert path.joinpath(CHECKPOINT_HF_FILENAME).is_file() + + +@pytest.mark.L0 +def test_experimental_repository_fallback(): + """Test fallback between experimental repositories.""" + import subprocess + from unittest.mock import patch + + experimental_repo = "nvidia/Cosmos-Experimental" + fallback_repo = "nvidia-cosmos-ea/Cosmos-Experimental" + + if INTERNAL: + return + + config = CheckpointFileHf( + repository=experimental_repo, + revision="main", + filename="test_file.pt", + ) + + call_count = [0] + + def mock_hf_download(cmd_args): + """Mock _hf_download to simulate first failure, second success.""" + call_count[0] += 1 + repo = cmd_args[0] + + if call_count[0] == 1: + # First call (experimental repo) - fail + assert repo == experimental_repo + raise subprocess.CalledProcessError(1, cmd_args) + elif call_count[0] == 2: + # Second call (fallback repo) - succeed + assert repo == fallback_repo + return "/path/to/test_file.pt" + else: + raise AssertionError(f"Unexpected call count: {call_count[0]}") + + # Mock os.path.exists to avoid file system checks in test + with patch("cosmos_transfer2._src.imaginaire.utils.checkpoint_db._hf_download", side_effect=mock_hf_download): + with patch("os.path.exists", return_value=True): + path = config._download() + assert path == "/path/to/test_file.pt" + assert call_count[0] == 2 + + +@pytest.mark.L0 +def test_non_experimental_repository_no_fallback(): + """Test that non-experimental repositories don't use fallback.""" + from unittest.mock import patch + + normal_repo = "nvidia/Cosmos-Predict2.5-2B" + + if INTERNAL: + return + + config = CheckpointFileHf( + repository=normal_repo, + revision="main", + filename="test_file.pt", + ) + + call_count = [0] + + def mock_hf_download(cmd_args): + """Mock _hf_download to simulate success on first call.""" + call_count[0] += 1 + repo = cmd_args[0] + + assert repo == normal_repo + return "/path/to/test_file.pt" + + with patch("cosmos_transfer2._src.imaginaire.utils.checkpoint_db._hf_download", side_effect=mock_hf_download): + with patch("os.path.exists", return_value=True): + path = config._download() + assert path == "/path/to/test_file.pt" + assert call_count[0] == 1 + + +@pytest.mark.L0 +def test_both_repositories_fail(): + """Test that RuntimeError is raised when both repositories fail.""" + import subprocess + from unittest.mock import patch + + if INTERNAL: + return + + config = CheckpointFileHf( + repository="nvidia/Cosmos-Experimental", + revision="main", + filename="test_file.pt", + ) + + def mock_hf_download_fail(cmd_args): + """Mock _hf_download to always fail.""" + raise subprocess.CalledProcessError(1, cmd_args) + + with patch("cosmos_transfer2._src.imaginaire.utils.checkpoint_db._hf_download", side_effect=mock_hf_download_fail): + with pytest.raises(RuntimeError, match="Failed to download test_file.pt from any experimental repository"): + config._download()