Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/gpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ jobs:

# One inference call over t2vs (+sound), action policy, and forward_dynamics; checks each output.
# MAX_GPUS defaults to 8. -s streams the live process log.
# Reuse the same input-asset cache dir as the unittest job.
- name: Nano inference smoke (t2vs + action policy + forward_dynamics, 8 GPU)
run: |
export LD_LIBRARY_PATH=
export COSMOS_DOWNLOAD_CACHE_DIR="$RUNNER_WORKSPACE/cosmos_input_cache"
uv run --all-extras --group=cu128-train python -m pytest -v -s \
tests/nano_inference_smoke_test.py --num-gpus=8 --levels=2 -o addopts=

Expand Down Expand Up @@ -193,9 +195,12 @@ jobs:
# is absent (via RunIf / pytest.skip guards), so this is green without
# internal credentials; provide the credential file on the runner to
# exercise them. New tests are picked up automatically (no markers/lists).
# Cache downloaded input assets in a persistent dir (outside the repo tree,
# so the cleanup step keeps it) and reuse it across runs.
- name: Unit tests
run: |
export LD_LIBRARY_PATH=
export COSMOS_DOWNLOAD_CACHE_DIR="$RUNNER_WORKSPACE/cosmos_input_cache"
uv run --all-extras --group=cu128-train python -m pytest -v -s \
cosmos_framework/ -o addopts=

Expand Down
77 changes: 73 additions & 4 deletions cosmos_framework/inference/common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

import contextlib
import glob
import hashlib
import itertools
import json
import os
import random
import re
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -52,7 +55,47 @@
MEDIA_EXTENSIONS = IMAGE_EXTENSIONS + VIDEO_EXTENSIONS


# Retry transient download errors with exponential backoff (env-overridable).
_DOWNLOAD_MAX_ATTEMPTS = int(os.environ.get("COSMOS_DOWNLOAD_MAX_ATTEMPTS", "6"))
_DOWNLOAD_BACKOFF_BASE_S = float(os.environ.get("COSMOS_DOWNLOAD_BACKOFF_S", "4"))
_DOWNLOAD_BACKOFF_CAP_S = float(os.environ.get("COSMOS_DOWNLOAD_BACKOFF_CAP_S", "60"))

# Statuses not worth retrying.
_PERMANENT_HTTP_MARKERS = ("400 Bad Request", "401 Unauthorized", "403 Forbidden", "404 Not Found")


def _is_permanent_download_error(exc: BaseException) -> bool:
if type(exc).__name__ in {"NotFoundError", "PermissionError"}:
return True
msg = str(exc)
return any(marker in msg for marker in _PERMANENT_HTTP_MARKERS)


def _download_file_url(url: str, path: Path):
"""Download ``url`` to ``path``, retrying transient network/server errors."""
from cosmos_framework.utils import log

last_exc: BaseException | None = None
for attempt in range(1, _DOWNLOAD_MAX_ATTEMPTS + 1):
try:
_download_file_url_once(url, path)
return
except Exception as exc: # noqa: BLE001
last_exc = exc
if _is_permanent_download_error(exc) or attempt == _DOWNLOAD_MAX_ATTEMPTS:
break
delay = min(_DOWNLOAD_BACKOFF_CAP_S, _DOWNLOAD_BACKOFF_BASE_S * 2 ** (attempt - 1))
delay += random.uniform(0, delay * 0.25) # jitter
log.warning(
f"Download attempt {attempt}/{_DOWNLOAD_MAX_ATTEMPTS} for {url} failed "
f"({type(exc).__name__}: {exc}); retrying in {delay:.1f}s."
)
time.sleep(delay)

raise RuntimeError(f"Failed to download {url} after {_DOWNLOAD_MAX_ATTEMPTS} attempt(s)") from last_exc


def _download_file_url_once(url: str, path: Path):
if "huggingface.co" in url:
_download_file_hf(url, path)
else:
Expand Down Expand Up @@ -85,6 +128,33 @@ def _download_file_hf(url: str, path: Path):
f.write(chunk)


def _resolve_url_download(url: str, name: str) -> Path:
"""Fetch ``url`` to a local file and return its path.

When ``COSMOS_DOWNLOAD_CACHE_DIR`` is set, downloads are cached there by URL
and reused across runs; otherwise a fresh temp dir is used per download.
"""
cache_root = os.environ.get("COSMOS_DOWNLOAD_CACHE_DIR")
if not cache_root:
local_path = Path(tempfile.mkdtemp()) / name
_download_file_url(url, local_path)
return local_path

cache_dir = Path(cache_root)
digest = hashlib.sha256(url.encode()).hexdigest()[:16]
cache_path = cache_dir / f"{digest}-{name}"
done_marker = Path(f"{cache_path}.done")
if cache_path.exists() and done_marker.exists():
return cache_path
cache_dir.mkdir(parents=True, exist_ok=True)
# Atomic move so concurrent writers never observe a half-written file.
tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
_download_file_url(url, tmp_path)
os.replace(tmp_path, cache_path)
done_marker.write_text(url)
return cache_path


def _download_file(url: str, path: Path):
if "://" not in url and Path(url).resolve() == path.resolve():
return
Expand All @@ -94,10 +164,9 @@ def _download_file(url: str, path: Path):
return

if "://" in url:
# Download to a temporary directory and symlink to the final path.
# This keeps the output directory small.
local_path = Path(tempfile.TemporaryDirectory(delete=False).name) / path.name
_download_file_url(url, local_path)
# Download (optionally via the persistent cache) and symlink to the final
# path. This keeps the output directory small.
local_path = _resolve_url_download(url, path.name)
else:
local_path = Path(url)

Expand Down
7 changes: 6 additions & 1 deletion cosmos_framework/inference/common/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
from pathlib import Path

import pytest

from cosmos_framework.inference.args import DEFAULT_CHECKPOINT, DEFAULT_CHECKPOINT_NAME
from cosmos_framework.inference.common.args import CheckpointConfig, CheckpointOverrides, download_file

Expand All @@ -13,7 +15,10 @@
}


def test_download_file(tmp_path: Path):
def test_download_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
# Disable the URL cache; this test asserts each download is independent.
monkeypatch.delenv("COSMOS_DOWNLOAD_CACHE_DIR", raising=False)

download_url_1 = (
"https://github.com/nvidia-cosmos/cosmos-dependencies/raw/2b17a2413bd86b2cf9b03823637108851e4ddf2d/inputs/vision/robot_153.jpg"
)
Expand Down