Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cosmos_framework/configs/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class DataSetting:
@attrs.define(slots=False)
class Config(config.Config):
data_setting: DataSetting = attrs.field(factory=DataSetting)
# Free-form, project-owned escape hatch fed by the SFT TOML's [custom] section.
# Default-empty so ${custom} interpolation and config.custom always resolve.
custom: dict = attrs.field(factory=dict)
Comment thread
yy-code-nv marked this conversation as resolved.
Outdated
defaults: List[Any] = attrs.field(
factory=lambda: [
"_self_",
Expand Down
3 changes: 3 additions & 0 deletions cosmos_framework/configs/base/vlm/defaults/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class DataSetting:
class Config(config.Config):
policy: PolicyConfig = PolicyConfig()
data_setting: DataSetting = DataSetting()
# Free-form, project-owned escape hatch fed by the SFT TOML's [custom] section.
# Default-empty so ${custom} interpolation and config.custom always resolve.
custom: dict = attrs.field(factory=dict)
defaults: List[Any] = attrs.field(
factory=lambda: [
"_self_",
Expand Down
37 changes: 27 additions & 10 deletions cosmos_framework/configs/toml_config/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,14 @@ class SFTExperimentConfig(BaseModel):
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
checkpoint: CheckpointConfig = Field(default_factory=CheckpointConfig)
dataloader_train: DataloaderTrainConfig = Field(default_factory=DataloaderTrainConfig)
custom: dict[str, Any] = Field(
default_factory=dict,
description=(
"Free-form, project-owned escape hatch. Arbitrary nested content "
"passes through verbatim — the framework never validates inside it. "
"Reachable as config.custom and via '${custom}' interpolation."
),
)


# ---------------------------------------------------------------------------
Expand All @@ -693,16 +701,16 @@ def load_experiment_from_toml(
["optimizer.lr=1e-5", "trainer.max_iter=200"]
["model.config.parallelism.data_parallel_shard_degree=4"]

Calls ``cosmos_framework.utils.config.load_config`` which:
The load then:

1. Imports the base config module and runs ``make_config()``. This
registers every config group (model, ema, tokenizer, ...) and imports
all experiment modules so their ``cs.store(group="experiment", ...)``
side-effects fire.
2. Runs ``override(config, overrides)`` — Hydra ``compose`` then resolves
the ``experiment=<name>`` selector against ``ConfigStore`` and applies
the dotted-path overrides we generated from the TOML, followed by
``extra_overrides``.
1. Imports the base config module and runs ``make_config()`` (registers
config groups + experiment modules).
2. Sets the TOML's ``[custom]`` table (if any) onto ``config.custom`` so it
is part of the OmegaConf tree Hydra resolves — kept out of
``build_hydra_overrides`` so it lands verbatim, not per-leaf-remapped.
3. Runs ``override(config, overrides)`` — Hydra ``compose`` resolves the
``experiment=<name>`` selector and applies the dotted-path overrides,
followed by ``extra_overrides``.

Returns the merged ``Config`` instance, ready for ``launch()``.
"""
Expand Down Expand Up @@ -740,4 +748,13 @@ def load_experiment_from_toml(
# Import lazily so this module stays cheap to import in non-training contexts.
from cosmos_framework.utils.config import load_config

return load_config(base_config_path, overrides)
# Inject [custom] before override() (via load_config's pre_override hook) so
# it is part of the OmegaConf tree Hydra resolves (enables ${custom}
# interpolation) and lands verbatim rather than being per-leaf-remapped.
custom = raw.get("custom")

def _inject_custom(config: Any) -> None:
if custom:
config.custom = custom

return load_config(base_config_path, overrides, pre_override=_inject_custom)
247 changes: 247 additions & 0 deletions cosmos_framework/configs/toml_config/sft_config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Loader tests for the free-form ``[custom]`` escape-hatch section of the SFT TOML."""

from __future__ import annotations

from pathlib import Path

import attrs
import pytest
from pydantic import ValidationError

from cosmos_framework.configs.toml_config.sft_config import SFTExperimentConfig
from cosmos_framework.configs.toml_config.toml_config_helper import build_hydra_overrides

# Representative payload: scalars, a nested sub-table, and an array-of-tables.
_CUSTOM_PAYLOAD = {
"scalar_int": 5,
"scalar_str": "hello",
"flag": True,
"ratio": 0.3,
"sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}},
"items": [
{"path": "/data/a", "weight": 1.0},
{"path": "/data/b", "weight": 2.0},
],
}


# --------------------------------------------------------------------------- #
# 1. pydantic schema validation #
# --------------------------------------------------------------------------- #
class TestSchemaValidation:
def test_custom_section_validates_arbitrary_nested_content(self) -> None:
"""Arbitrary nested [custom] content passes through untouched."""
raw = {
"job": {"task": "vfm", "experiment": "vision_sft_nano"},
"custom": _CUSTOM_PAYLOAD,
}
cfg = SFTExperimentConfig.model_validate(raw)
# The framework stores it verbatim — no coercion, no inner validation.
assert cfg.custom == _CUSTOM_PAYLOAD

def test_no_custom_section_defaults_empty(self) -> None:
cfg = SFTExperimentConfig.model_validate({"job": {"task": "vfm", "experiment": "vision_sft_nano"}})
assert cfg.custom == {}

def test_unknown_top_level_key_raises(self) -> None:
"""Any unknown top-level section that is NOT `custom` still raises."""
with pytest.raises(ValidationError):
SFTExperimentConfig.model_validate(
{
"job": {"task": "vfm", "experiment": "vision_sft_nano"},
"bogus_section": {"x": 1},
}
)

def test_unknown_key_inside_optimizer_raises(self) -> None:
"""A typo inside a KNOWN section is still a hard error (extra='forbid')."""
with pytest.raises(ValidationError):
SFTExperimentConfig.model_validate(
{
"job": {"task": "vfm", "experiment": "vision_sft_nano"},
"optimizer": {"lr": 1.0e-4, "not_a_real_key": 1},
}
)

def test_custom_does_not_loosen_sibling_validation(self) -> None:
"""Presence of [custom] must not relax extra='forbid' elsewhere."""
with pytest.raises(ValidationError):
SFTExperimentConfig.model_validate(
{
"job": {"task": "vfm", "experiment": "vision_sft_nano"},
"custom": _CUSTOM_PAYLOAD,
"trainer": {"max_iter": 10, "typo_here": True},
}
)


# --------------------------------------------------------------------------- #
# 2. build_hydra_overrides must NOT emit [custom] as per-leaf overrides #
# --------------------------------------------------------------------------- #
class TestBuildHydraOverrides:
def test_custom_not_emitted_as_overrides(self) -> None:
raw = {
"job": {"task": "vfm", "experiment": "vision_sft_nano"},
"optimizer": {"lr": 1.0e-5},
"custom": _CUSTOM_PAYLOAD,
}
overrides = build_hydra_overrides(raw)
# Nothing under custom (verbatim or remapped) should appear.
assert all("custom" not in o for o in overrides), overrides

def test_other_keys_still_emitted(self) -> None:
raw = {
"job": {"task": "vfm", "experiment": "vision_sft_nano"},
"optimizer": {"lr": 1.0e-5},
"custom": {"a": 1},
}
overrides = build_hydra_overrides(raw)
assert "experiment=vision_sft_nano" in overrides
assert any(o.startswith("optimizer.lr=") for o in overrides), overrides


# --------------------------------------------------------------------------- #
# 3. interpolation (${custom}) + attribute access via the real override path #
# --------------------------------------------------------------------------- #
@attrs.define(slots=False)
class _ProbeConfig:
"""Minimal attrs config (no ``defaults`` field) so ``override`` resolves it
without needing registered config groups / the training stack."""

custom: dict = attrs.field(factory=dict)
dataloader_train: object = None


class TestCustomInterpolationAndAccess:
Comment thread
foreverlms marked this conversation as resolved.
Outdated
def test_interpolation_resolves_and_attribute_access(self) -> None:
from omegaconf import OmegaConf

from cosmos_framework.utils.config_helper import override
from cosmos_framework.utils.lazy_config import instantiate

# builtins.dict is a LazyCall-shaped, instantiate-locatable target.
probe = _ProbeConfig(
custom=dict(_CUSTOM_PAYLOAD),
dataloader_train={
"_target_": "builtins.dict",
"whole": "${custom}",
"leaf": "${custom.scalar_int}",
"deep": "${custom.sampling.nested.deep}",
},
)

resolved = override(probe, ["--"])

# attribute access + to_container round-trips to the plain dict.
assert OmegaConf.to_container(resolved.custom, resolve=True) == _CUSTOM_PAYLOAD

# ${custom...} interpolations resolved against the top-level node.
dl = resolved.dataloader_train
assert OmegaConf.to_container(dl["whole"], resolve=True) == _CUSTOM_PAYLOAD
assert dl["leaf"] == _CUSTOM_PAYLOAD["scalar_int"]
assert dl["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"]

# instantiate() yields the resolved content.
obj = instantiate(resolved.dataloader_train)
assert obj["leaf"] == _CUSTOM_PAYLOAD["scalar_int"]
assert obj["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"]
assert OmegaConf.to_container(obj["whole"], resolve=True) == _CUSTOM_PAYLOAD


# --------------------------------------------------------------------------- #
# 4. end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe #
# --------------------------------------------------------------------------- #
_BASE_TOML = """\
[job]
task = "vfm"
experiment = "vision_sft_nano"
project = "cosmos3"
group = "sft"
name = "sft_config_custom_test"
wandb_mode = "disabled"

[model.tokenizer]
vae_path = "${oc.env:WAN_VAE_PATH}"

[checkpoint]
load_path = "${oc.env:BASE_CHECKPOINT_PATH}"
"""

_CUSTOM_TOML_BLOCK = """\

[custom]
scalar_int = 5
scalar_str = "hello"
flag = true
ratio = 0.3

[custom.sampling]
bug_ratio = 0.3

[custom.sampling.nested]
deep = 1

[[custom.items]]
path = "/data/a"
weight = 1.0

[[custom.items]]
path = "/data/b"
weight = 2.0
"""


def _load_or_skip(toml_path: Path):
"""Run the real loader, skipping if the training stack can't be imported."""
from cosmos_framework.configs.toml_config.sft_config import load_experiment_from_toml

try:
return load_experiment_from_toml(str(toml_path))
except ImportError as exc: # pragma: no cover — env-dependent
pytest.skip(f"training stack not importable here: {exc!r}")


@pytest.fixture
def _dummy_recipe_env(monkeypatch: pytest.MonkeyPatch) -> None:
# vision_sft_nano interpolates these env vars into path strings at resolve time.
monkeypatch.setenv("DATASET_PATH", "/tmp/dummy_dataset")
monkeypatch.setenv("WAN_VAE_PATH", "/tmp/dummy_vae.pth")
monkeypatch.setenv("BASE_CHECKPOINT_PATH", "/tmp/dummy_ckpt")


class TestEndToEndLoader:
def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) -> None:
from omegaconf import OmegaConf

toml_path = tmp_path / "with_custom.toml"
toml_path.write_text(_BASE_TOML + _CUSTOM_TOML_BLOCK)

config = _load_or_skip(toml_path)

expected = {
"scalar_int": 5,
"scalar_str": "hello",
"flag": True,
"ratio": 0.3,
"sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}},
"items": [
{"path": "/data/a", "weight": 1.0},
{"path": "/data/b", "weight": 2.0},
],
}
# Reachable by attribute access and convertible to a plain dict, so a
# project can run MyProjectConfig.model_validate(<that dict>).
assert OmegaConf.to_container(config.custom, resolve=True) == expected

def test_load_without_custom_section_defaults_empty(self, tmp_path: Path, _dummy_recipe_env: None) -> None:
from omegaconf import OmegaConf

toml_path = tmp_path / "no_custom.toml"
toml_path.write_text(_BASE_TOML)

config = _load_or_skip(toml_path)

assert OmegaConf.to_container(config.custom, resolve=True) == {}
3 changes: 3 additions & 0 deletions cosmos_framework/configs/toml_config/toml_config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def build_hydra_overrides(toml_dict: dict) -> list[str]:

overlay = dict(toml_dict)
overlay["job"] = job
# [custom] lands verbatim on config.custom (see load_experiment_from_toml),
# so it must not be per-leaf-remapped into Hydra overrides here.
overlay.pop("custom", None)

for top_key, val in overlay.items():
_emit_with_remap(overrides, [top_key], val, rules)
Expand Down
29 changes: 25 additions & 4 deletions cosmos_framework/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import importlib
import os
import time
from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union

import attrs
import torch
Expand Down Expand Up @@ -517,7 +517,18 @@ def validate(self) -> None:
assert self.job.name != ""


def load_config(config_path: str, opts: list[str], enable_one_logger: bool = False) -> Config:
def load_config(
config_path: str,
opts: list[str],
enable_one_logger: bool = False,
pre_override: Optional[Callable[[Config], None]] = None,
) -> Config:
"""Load a config from a ``.yaml`` or ``.py`` path and apply ``opts``.

``pre_override`` is an optional hook called on the freshly-built config
(after ``make_config()``, before ``override()``) — use it to mutate the
config so the change is part of the OmegaConf tree Hydra resolves.
"""
from cosmos_framework.utils.serialization import from_yaml, load_callable

t1 = time.monotonic_ns()
Expand All @@ -528,9 +539,11 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal

from cosmos_framework.utils.config_helper import override

if pre_override is not None:
pre_override(config)
config = override(config, opts, remove_defaults=True)
else:
config = _load_py_config(config_path, opts, validate=False)
config = _load_py_config(config_path, opts, validate=False, pre_override=pre_override)

if enable_one_logger:
try:
Expand All @@ -549,7 +562,12 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal
return config


def _load_py_config(config_path: str, opts: list[str], validate: bool = True) -> Config:
def _load_py_config(
config_path: str,
opts: list[str],
validate: bool = True,
pre_override: Optional[Callable[[Config], None]] = None,
) -> Config:
# NOTE: circular dependency
from cosmos_framework.utils.config_helper import get_config_module, override

Expand All @@ -563,6 +581,9 @@ def _load_py_config(config_path: str, opts: list[str], validate: bool = True) ->
t2 = time.monotonic_ns()
logging.debug(f"importlib.import_module: took {(t2 - t1) / 1e6:.2f}ms")

if pre_override is not None:
pre_override(config)

t1 = time.monotonic_ns()
config = override(config, opts)
t2 = time.monotonic_ns()
Expand Down
Loading