Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions cosmos_transfer2/_src/imaginaire/utils/checkpoint_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,32 @@ class CheckpointFileHf(_CheckpointHf):
@override
def _download(self) -> str:
    """Download the checkpoint file and return its local path.

    The two experimental repositories ("nvidia/Cosmos-Experimental" and
    "nvidia-cosmos-ea/Cosmos-Experimental") serve as fallbacks for each
    other: the configured repository is tried first and the other one
    second. Any other repository is tried on its own, with no fallback.

    Returns:
        The path returned by ``_hf_download`` for the first repository
        whose download succeeds.

    Raises:
        RuntimeError: If ``_hf_download`` raised
            ``subprocess.CalledProcessError`` for every candidate
            repository. The last such error is attached as ``__cause__``.
        AssertionError: If the returned path does not exist on disk.
    """
    experimental_repositories = (
        "nvidia/Cosmos-Experimental",
        "nvidia-cosmos-ea/Cosmos-Experimental",
    )
    is_experimental = self.repository in experimental_repositories

    # The configured repository always goes first; the sibling experimental
    # repository (if any) is only a fallback.
    repositories = [self.repository]
    if is_experimental:
        repositories.extend(
            candidate for candidate in experimental_repositories if candidate != self.repository
        )

    last_error = None
    for repository in repositories:
        cmd_args = [
            repository,
            "--repo-type",
            "model",
            "--revision",
            self.revision,
            self.filename,
        ]
        # Keep the try body minimal: only the download itself is expected to
        # raise CalledProcessError.
        try:
            path = _hf_download(cmd_args)
        except subprocess.CalledProcessError as error:
            last_error = error
            continue
        assert os.path.exists(path), path
        return path

    if is_experimental:
        message = f"Failed to download {self.filename} from any experimental repository."
    else:
        # Do not blame "experimental" repositories when none were involved.
        message = f"Failed to download {self.filename} from {self.repository}."
    # Chain the underlying failure so the subprocess error stays visible.
    raise RuntimeError(message) from last_error


class CheckpointDirHf(_CheckpointHf):
Expand Down
101 changes: 101 additions & 0 deletions cosmos_transfer2/_src/imaginaire/utils/checkpoint_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,104 @@ def test_get_checkpoint_dir():
path = Path(path)
assert path.is_dir()
assert path.joinpath(CHECKPOINT_HF_FILENAME).is_file()


@pytest.mark.L0
def test_experimental_repository_fallback():
    """Check that a failed download from one experimental repo falls back to the other."""
    import subprocess
    from unittest.mock import patch

    if INTERNAL:
        return

    primary_repo = "nvidia/Cosmos-Experimental"
    secondary_repo = "nvidia-cosmos-ea/Cosmos-Experimental"
    local_path = "/path/to/test_file.pt"

    checkpoint = CheckpointFileHf(
        repository=primary_repo,
        revision="main",
        filename="test_file.pt",
    )

    attempts = 0

    def fake_hf_download(cmd_args):
        """Fail on the first attempt, succeed on the second, reject any further attempt."""
        nonlocal attempts
        attempts += 1
        requested_repo = cmd_args[0]

        if attempts == 1:
            # The configured repository is tried first and fails.
            assert requested_repo == primary_repo
            raise subprocess.CalledProcessError(1, cmd_args)
        if attempts == 2:
            # The fallback repository is tried next and succeeds.
            assert requested_repo == secondary_repo
            return local_path
        raise AssertionError(f"Unexpected call count: {attempts}")

    download_patch = patch(
        "cosmos_transfer2._src.imaginaire.utils.checkpoint_db._hf_download",
        side_effect=fake_hf_download,
    )
    # os.path.exists is patched so the fake path passes the existence check.
    exists_patch = patch("os.path.exists", return_value=True)

    with download_patch, exists_patch:
        result = checkpoint._download()

    assert result == local_path
    assert attempts == 2


@pytest.mark.L0
def test_non_experimental_repository_no_fallback():
    """Check that a non-experimental repository is downloaded from directly, exactly once."""
    from unittest.mock import patch

    if INTERNAL:
        return

    regular_repo = "nvidia/Cosmos-Predict2.5-2B"
    local_path = "/path/to/test_file.pt"

    checkpoint = CheckpointFileHf(
        repository=regular_repo,
        revision="main",
        filename="test_file.pt",
    )

    requested_repos = []

    def fake_hf_download(cmd_args):
        """Record the requested repository and succeed immediately."""
        requested_repos.append(cmd_args[0])
        assert cmd_args[0] == regular_repo
        return local_path

    download_patch = patch(
        "cosmos_transfer2._src.imaginaire.utils.checkpoint_db._hf_download",
        side_effect=fake_hf_download,
    )
    # os.path.exists is patched so the fake path passes the existence check.
    exists_patch = patch("os.path.exists", return_value=True)

    with download_patch, exists_patch:
        result = checkpoint._download()

    assert result == local_path
    assert len(requested_repos) == 1


@pytest.mark.L0
def test_both_repositories_fail():
    """Check that RuntimeError is raised only after both experimental repos were tried.

    The mock records every repository it is asked for, so the test fails if
    ``_download`` gives up after the first failure instead of attempting the
    fallback repository.
    """
    import subprocess
    from unittest.mock import patch

    if INTERNAL:
        return

    config = CheckpointFileHf(
        repository="nvidia/Cosmos-Experimental",
        revision="main",
        filename="test_file.pt",
    )

    attempted_repos = []

    def mock_hf_download_fail(cmd_args):
        """Record the requested repository, then fail like a failed subprocess."""
        attempted_repos.append(cmd_args[0])
        raise subprocess.CalledProcessError(1, cmd_args)

    with patch("cosmos_transfer2._src.imaginaire.utils.checkpoint_db._hf_download", side_effect=mock_hf_download_fail):
        with pytest.raises(RuntimeError, match="Failed to download test_file.pt from any experimental repository"):
            config._download()

    # Configured repository first, then its experimental sibling.
    assert attempted_repos == [
        "nvidia/Cosmos-Experimental",
        "nvidia-cosmos-ea/Cosmos-Experimental",
    ]