From a19b1432b275253ad687c463a5d582187f579501 Mon Sep 17 00:00:00 2001 From: Alex Peng Date: Thu, 25 Jun 2026 14:22:56 +0800 Subject: [PATCH 1/6] Add custom section support to SFT TOML configuration Signed-off-by: Alex Peng --- cosmos_framework/configs/base/config.py | 3 + .../configs/base/vlm/defaults/config.py | 3 + .../configs/toml_config/sft_config.py | 42 ++- .../configs/toml_config/toml_config_helper.py | 3 + docs/sft_config.md | 35 ++- tests/toml_custom_section_test.py | 247 ++++++++++++++++++ 6 files changed, 320 insertions(+), 13 deletions(-) create mode 100644 tests/toml_custom_section_test.py diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py index e766c5c..5293c23 100644 --- a/cosmos_framework/configs/base/config.py +++ b/cosmos_framework/configs/base/config.py @@ -25,6 +25,9 @@ class DataSetting: @attrs.define(slots=False) class Config(config.Config): data_setting: DataSetting = attrs.field(factory=DataSetting) + # Free-form, project-owned escape hatch fed by the SFT TOML's [custom] section. + # Default-empty so ${custom} interpolation and config.custom always resolve. + custom: dict = attrs.field(factory=dict) defaults: List[Any] = attrs.field( factory=lambda: [ "_self_", diff --git a/cosmos_framework/configs/base/vlm/defaults/config.py b/cosmos_framework/configs/base/vlm/defaults/config.py index 2128072..6145903 100644 --- a/cosmos_framework/configs/base/vlm/defaults/config.py +++ b/cosmos_framework/configs/base/vlm/defaults/config.py @@ -52,6 +52,9 @@ class DataSetting: class Config(config.Config): policy: PolicyConfig = PolicyConfig() data_setting: DataSetting = DataSetting() + # Free-form, project-owned escape hatch fed by the SFT TOML's [custom] section. + # Default-empty so ${custom} interpolation and config.custom always resolve. + custom: dict = attrs.field(factory=dict) defaults: List[Any] = attrs.field( factory=lambda: [ "_self_", diff --git a/cosmos_framework/configs/toml_config/sft_config.py b/cosmos_framework/configs/toml_config/sft_config.py index a15d00a..611b41c 100644 --- a/cosmos_framework/configs/toml_config/sft_config.py +++ b/cosmos_framework/configs/toml_config/sft_config.py @@ -669,6 +669,14 @@ class SFTExperimentConfig(BaseModel): trainer: TrainerConfig = Field(default_factory=TrainerConfig) checkpoint: CheckpointConfig = Field(default_factory=CheckpointConfig) dataloader_train: DataloaderTrainConfig = Field(default_factory=DataloaderTrainConfig) + custom: dict[str, Any] = Field( + default_factory=dict, + description=( + "Free-form, project-owned escape hatch. Arbitrary nested content " + "passes through verbatim — the framework never validates inside it. " + "Reachable as config.custom and via '${custom}' interpolation." + ), + ) # --------------------------------------------------------------------------- @@ -693,16 +701,16 @@ def load_experiment_from_toml( ["optimizer.lr=1e-5", "trainer.max_iter=200"] ["model.config.parallelism.data_parallel_shard_degree=4"] - Calls ``cosmos_framework.utils.config.load_config`` which: + The load then: - 1. Imports the base config module and runs ``make_config()``. This - registers every config group (model, ema, tokenizer, ...) and imports - all experiment modules so their ``cs.store(group="experiment", ...)`` - side-effects fire. - 2. Runs ``override(config, overrides)`` — Hydra ``compose`` then resolves - the ``experiment=`` selector against ``ConfigStore`` and applies - the dotted-path overrides we generated from the TOML, followed by - ``extra_overrides``. + 1. Imports the base config module and runs ``make_config()`` (registers + config groups + experiment modules). + 2. Sets the TOML's ``[custom]`` table (if any) onto ``config.custom`` so it + is part of the OmegaConf tree Hydra resolves — kept out of + ``build_hydra_overrides`` so it lands verbatim, not per-leaf-remapped. + 3. Runs ``override(config, overrides)`` — Hydra ``compose`` resolves the + ``experiment=`` selector and applies the dotted-path overrides, + followed by ``extra_overrides``. Returns the merged ``Config`` instance, ready for ``launch()``. """ @@ -738,6 +746,18 @@ def load_experiment_from_toml( overrides.append(o) # Import lazily so this module stays cheap to import in non-training contexts. - from cosmos_framework.utils.config import load_config + import importlib + + from cosmos_framework.utils.config_helper import get_config_module, override + + # Set [custom] on the base Config before override() so it is part of the + # OmegaConf tree Hydra resolves (enables ${custom} interpolation) and lands + # verbatim rather than being per-leaf-remapped by build_hydra_overrides. + config_module = get_config_module(base_config_path) + config = importlib.import_module(config_module).make_config() + + custom = raw.get("custom") + if custom: + config.custom = custom - return load_config(base_config_path, overrides) + return override(config, overrides) diff --git a/cosmos_framework/configs/toml_config/toml_config_helper.py b/cosmos_framework/configs/toml_config/toml_config_helper.py index ac54696..82333cf 100644 --- a/cosmos_framework/configs/toml_config/toml_config_helper.py +++ b/cosmos_framework/configs/toml_config/toml_config_helper.py @@ -138,6 +138,9 @@ def build_hydra_overrides(toml_dict: dict) -> list[str]: overlay = dict(toml_dict) overlay["job"] = job + # [custom] lands verbatim on config.custom (see load_experiment_from_toml), + # so it must not be per-leaf-remapped into Hydra overrides here. + overlay.pop("custom", None) for top_key, val in overlay.items(): _emit_with_remap(overrides, [top_key], val, rules) diff --git a/docs/sft_config.md b/docs/sft_config.md index 6762f40..6b79dab 100644 --- a/docs/sft_config.md +++ b/docs/sft_config.md @@ -23,6 +23,7 @@ ______________________________________________________________________ - [`[trainer.callbacks.grad_clip]`](#trainercallbacksgrad_clip) - [`[checkpoint]`](#checkpoint) - [`[dataloader_train]`](#dataloader_train) +- [`[custom]` (free-form escape hatch)](#custom-free-form-escape-hatch) - [Cross-cutting behaviors](#cross-cutting-behaviors) - [`"???"` (MISSING) sentinel](#-missing-sentinel) - [Env interpolation](#env-interpolation) @@ -60,6 +61,7 @@ After validation, the TOML dict is converted to a Hydra override list by [`build [trainer.callbacks.grad_clip] # clip_norm + force_finite [checkpoint] # load_path, save_iter, key-skip blocklist [dataloader_train] # top-level scalars only +[custom] # free-form, project-owned escape hatch (opaque to the framework) ``` The full pipeline (dataloader class, dataset wiring, model_instance LazyCall, etc.) lives in the experiment SKU Python file under `cosmos_framework/configs/base/experiment/sft/.py`. The TOML only surfaces values the recipe author wants users to tune. @@ -225,6 +227,35 @@ Top-level dataloader scalars only. The dataloader's class (LazyCall) and full pi | `max_sequence_length` | `null` | Cap on tokens per packed sequence. Remapped to `max_tokens` on the VLM `DataPackerDataLoader`. `null` = no per-token cap. | | `seed` | `42` | Dataloader RNG seed. **VFM only** — skipped on VLM (DataPackerDataLoader has no `seed` ctor kwarg). | +## `[custom]` (free-form escape hatch) + +`[custom]` is the **one** section the framework does not model, validate, or remap. It exists so a project built on cosmos-framework can carry its own config (dataset paths, sampling ratios, pairing constraints, …) in the **same** TOML as the framework training knobs — instead of a second sidecar file. + +Rules: + +- **Arbitrary nested content** is allowed and passes through verbatim: scalars, sub-tables (`[custom.a.b]`), and arrays-of-tables (`[[custom.items]]`). The framework never validates *inside* `[custom]` (the schema field is a plain `dict[str, Any]`). +- It is the **only** top-level key exempt from the `extra="forbid"` typo guard's *contents*. Every other section — and any unknown top-level key that isn't `custom` — still raises `ValidationError`. +- It is **not** routed through `PATH_REMAPS`. Instead of per-leaf Hydra overrides, the whole `[custom]` table is attached verbatim onto a top-level `custom` node on the resolved `Config`. +- The resolved node is reachable two ways: + - **OmegaConf/Hydra interpolation** from an experiment recipe — e.g. a LazyCall arg `config="${custom}"` or `"${custom.some_key}"` resolves to the custom content; and + - **attribute access** — `config.custom`. +- It converts back to a plain dict via `OmegaConf.to_container(config.custom, resolve=True)`, so a project can run `MyProjectConfig.model_validate()`. + +When `[custom]` is absent, `config.custom` is a default-empty node (`{}`) — `${custom}` still resolves and existing TOMLs are unaffected. + +```toml +[custom] +your_custom_files = "custom_value" +``` + +A project's experiment recipe then wires its data pipeline straight from `[custom]`, e.g.: + +```python +config = L(TrainingDatasetConfig.model_validate)("${custom}") +``` + +so a single TOML drives both the framework training and the project data pipeline. + ## Cross-cutting behaviors ### `"???"` (MISSING) sentinel @@ -289,9 +320,9 @@ A few useful knobs aren't currently modeled by `SFTExperimentConfig` because the 1. Reads the TOML with `tomllib`. 2. Validates the parsed dict against `SFTExperimentConfig` (raises `ValidationError` on unknown keys). 3. Picks the base config from `[job].task`: `TASK_TO_BASE_CONFIG["vfm"|"vlm"]`. -4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. +4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 6, not per-leaf-remapped). 5. Appends `extra_overrides` (CLI tail) so they take precedence over the TOML. -6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module (running `make_config()` to register every config group and import every experiment SKU's `cs.store(group="experiment", …)`), then runs `override(config, overrides)` — Hydra `compose` resolves the `experiment=` selector against `ConfigStore` and applies the dotted-path overrides. +6. Imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`), sets the TOML's `[custom]` table onto `config.custom` (so it is part of the single OmegaConf tree Hydra resolves — that's what lets `${custom}` interpolations resolve), then runs `override(config, overrides)` — Hydra `compose` resolves the `experiment=` selector against `ConfigStore` and applies the dotted-path overrides. The returned `Config` is ready for `launch()`. diff --git a/tests/toml_custom_section_test.py b/tests/toml_custom_section_test.py new file mode 100644 index 0000000..77a23ef --- /dev/null +++ b/tests/toml_custom_section_test.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Loader tests for the free-form ``[custom]`` escape-hatch section of the SFT TOML.""" + +from __future__ import annotations + +from pathlib import Path + +import attrs +import pytest +from pydantic import ValidationError + +from cosmos_framework.configs.toml_config.sft_config import SFTExperimentConfig +from cosmos_framework.configs.toml_config.toml_config_helper import build_hydra_overrides + +# Representative payload: scalars, a nested sub-table, and an array-of-tables. +_CUSTOM_PAYLOAD = { + "scalar_int": 5, + "scalar_str": "hello", + "flag": True, + "ratio": 0.3, + "sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}}, + "items": [ + {"path": "/data/a", "weight": 1.0}, + {"path": "/data/b", "weight": 2.0}, + ], +} + + +# --------------------------------------------------------------------------- # +# 1. pydantic schema validation # +# --------------------------------------------------------------------------- # +class TestSchemaValidation: + def test_custom_section_validates_arbitrary_nested_content(self) -> None: + """Arbitrary nested [custom] content passes through untouched.""" + raw = { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "custom": _CUSTOM_PAYLOAD, + } + cfg = SFTExperimentConfig.model_validate(raw) + # The framework stores it verbatim — no coercion, no inner validation. + assert cfg.custom == _CUSTOM_PAYLOAD + + def test_no_custom_section_defaults_empty(self) -> None: + cfg = SFTExperimentConfig.model_validate({"job": {"task": "vfm", "experiment": "vision_sft_nano"}}) + assert cfg.custom == {} + + def test_unknown_top_level_key_raises(self) -> None: + """Any unknown top-level section that is NOT `custom` still raises.""" + with pytest.raises(ValidationError): + SFTExperimentConfig.model_validate( + { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "bogus_section": {"x": 1}, + } + ) + + def test_unknown_key_inside_optimizer_raises(self) -> None: + """A typo inside a KNOWN section is still a hard error (extra='forbid').""" + with pytest.raises(ValidationError): + SFTExperimentConfig.model_validate( + { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "optimizer": {"lr": 1.0e-4, "not_a_real_key": 1}, + } + ) + + def test_custom_does_not_loosen_sibling_validation(self) -> None: + """Presence of [custom] must not relax extra='forbid' elsewhere.""" + with pytest.raises(ValidationError): + SFTExperimentConfig.model_validate( + { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "custom": _CUSTOM_PAYLOAD, + "trainer": {"max_iter": 10, "typo_here": True}, + } + ) + + +# --------------------------------------------------------------------------- # +# 2. build_hydra_overrides must NOT emit [custom] as per-leaf overrides # +# --------------------------------------------------------------------------- # +class TestBuildHydraOverrides: + def test_custom_not_emitted_as_overrides(self) -> None: + raw = { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "optimizer": {"lr": 1.0e-5}, + "custom": _CUSTOM_PAYLOAD, + } + overrides = build_hydra_overrides(raw) + # Nothing under custom (verbatim or remapped) should appear. + assert all("custom" not in o for o in overrides), overrides + + def test_other_keys_still_emitted(self) -> None: + raw = { + "job": {"task": "vfm", "experiment": "vision_sft_nano"}, + "optimizer": {"lr": 1.0e-5}, + "custom": {"a": 1}, + } + overrides = build_hydra_overrides(raw) + assert "experiment=vision_sft_nano" in overrides + assert any(o.startswith("optimizer.lr=") for o in overrides), overrides + + +# --------------------------------------------------------------------------- # +# 3. interpolation (${custom}) + attribute access via the real override path # +# --------------------------------------------------------------------------- # +@attrs.define(slots=False) +class _ProbeConfig: + """Minimal attrs config (no ``defaults`` field) so ``override`` resolves it + without needing registered config groups / the training stack.""" + + custom: dict = attrs.field(factory=dict) + dataloader_train: object = None + + +class TestCustomInterpolationAndAccess: + def test_interpolation_resolves_and_attribute_access(self) -> None: + from omegaconf import OmegaConf + + from cosmos_framework.utils.config_helper import override + from cosmos_framework.utils.lazy_config import instantiate + + # builtins.dict is a LazyCall-shaped, instantiate-locatable target. + probe = _ProbeConfig( + custom=dict(_CUSTOM_PAYLOAD), + dataloader_train={ + "_target_": "builtins.dict", + "whole": "${custom}", + "leaf": "${custom.scalar_int}", + "deep": "${custom.sampling.nested.deep}", + }, + ) + + resolved = override(probe, ["--"]) + + # attribute access + to_container round-trips to the plain dict. + assert OmegaConf.to_container(resolved.custom, resolve=True) == _CUSTOM_PAYLOAD + + # ${custom...} interpolations resolved against the top-level node. + dl = resolved.dataloader_train + assert OmegaConf.to_container(dl["whole"], resolve=True) == _CUSTOM_PAYLOAD + assert dl["leaf"] == _CUSTOM_PAYLOAD["scalar_int"] + assert dl["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"] + + # instantiate() yields the resolved content. + obj = instantiate(resolved.dataloader_train) + assert obj["leaf"] == _CUSTOM_PAYLOAD["scalar_int"] + assert obj["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"] + assert OmegaConf.to_container(obj["whole"], resolve=True) == _CUSTOM_PAYLOAD + + +# --------------------------------------------------------------------------- # +# 4. end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe # +# --------------------------------------------------------------------------- # +_BASE_TOML = """\ +[job] +task = "vfm" +experiment = "vision_sft_nano" +project = "cosmos3" +group = "sft" +name = "toml_custom_section_test" +wandb_mode = "disabled" + +[model.tokenizer] +vae_path = "${oc.env:WAN_VAE_PATH}" + +[checkpoint] +load_path = "${oc.env:BASE_CHECKPOINT_PATH}" +""" + +_CUSTOM_TOML_BLOCK = """\ + +[custom] +scalar_int = 5 +scalar_str = "hello" +flag = true +ratio = 0.3 + +[custom.sampling] +bug_ratio = 0.3 + +[custom.sampling.nested] +deep = 1 + +[[custom.items]] +path = "/data/a" +weight = 1.0 + +[[custom.items]] +path = "/data/b" +weight = 2.0 +""" + + +def _load_or_skip(toml_path: Path): + """Run the real loader, skipping if the training stack can't be imported.""" + from cosmos_framework.configs.toml_config.sft_config import load_experiment_from_toml + + try: + return load_experiment_from_toml(str(toml_path)) + except ImportError as exc: # pragma: no cover — env-dependent + pytest.skip(f"training stack not importable here: {exc!r}") + + +@pytest.fixture +def _dummy_recipe_env(monkeypatch: pytest.MonkeyPatch) -> None: + # vision_sft_nano interpolates these env vars into path strings at resolve time. + monkeypatch.setenv("DATASET_PATH", "/tmp/dummy_dataset") + monkeypatch.setenv("WAN_VAE_PATH", "/tmp/dummy_vae.pth") + monkeypatch.setenv("BASE_CHECKPOINT_PATH", "/tmp/dummy_ckpt") + + +class TestEndToEndLoader: + def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) -> None: + from omegaconf import OmegaConf + + toml_path = tmp_path / "with_custom.toml" + toml_path.write_text(_BASE_TOML + _CUSTOM_TOML_BLOCK) + + config = _load_or_skip(toml_path) + + expected = { + "scalar_int": 5, + "scalar_str": "hello", + "flag": True, + "ratio": 0.3, + "sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}}, + "items": [ + {"path": "/data/a", "weight": 1.0}, + {"path": "/data/b", "weight": 2.0}, + ], + } + # Reachable by attribute access and convertible to a plain dict, so a + # project can run MyProjectConfig.model_validate(). + assert OmegaConf.to_container(config.custom, resolve=True) == expected + + def test_load_without_custom_section_defaults_empty(self, tmp_path: Path, _dummy_recipe_env: None) -> None: + from omegaconf import OmegaConf + + toml_path = tmp_path / "no_custom.toml" + toml_path.write_text(_BASE_TOML) + + config = _load_or_skip(toml_path) + + assert OmegaConf.to_container(config.custom, resolve=True) == {} From 0be6429c1cf390f94d247b72fcb2b7a339d420b6 Mon Sep 17 00:00:00 2001 From: Alex Peng Date: Thu, 25 Jun 2026 15:55:07 +0800 Subject: [PATCH 2/6] Refactor SFT config loading to support custom section injection Updated the `load_experiment_from_toml` function to utilize `load_config` with a `pre_override` hook for injecting the `[custom]` section into the config before applying overrides. This ensures that custom configurations are part of the OmegaConf tree resolved by Hydra. Additionally, modified the `load_config` and `_load_py_config` functions to accept the `pre_override` parameter for enhanced configurability. Documentation for SFT configuration has been updated to reflect these changes. --- .../configs/toml_config/sft_config.py | 21 ++++++-------- cosmos_framework/utils/config.py | 29 ++++++++++++++++--- docs/sft_config.md | 2 +- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/cosmos_framework/configs/toml_config/sft_config.py b/cosmos_framework/configs/toml_config/sft_config.py index 611b41c..e35461a 100644 --- a/cosmos_framework/configs/toml_config/sft_config.py +++ b/cosmos_framework/configs/toml_config/sft_config.py @@ -746,18 +746,15 @@ def load_experiment_from_toml( overrides.append(o) # Import lazily so this module stays cheap to import in non-training contexts. - import importlib - - from cosmos_framework.utils.config_helper import get_config_module, override - - # Set [custom] on the base Config before override() so it is part of the - # OmegaConf tree Hydra resolves (enables ${custom} interpolation) and lands - # verbatim rather than being per-leaf-remapped by build_hydra_overrides. - config_module = get_config_module(base_config_path) - config = importlib.import_module(config_module).make_config() + from cosmos_framework.utils.config import load_config + # Inject [custom] before override() (via load_config's pre_override hook) so + # it is part of the OmegaConf tree Hydra resolves (enables ${custom} + # interpolation) and lands verbatim rather than being per-leaf-remapped. custom = raw.get("custom") - if custom: - config.custom = custom - return override(config, overrides) + def _inject_custom(config: Any) -> None: + if custom: + config.custom = custom + + return load_config(base_config_path, overrides, pre_override=_inject_custom) diff --git a/cosmos_framework/utils/config.py b/cosmos_framework/utils/config.py index c59d689..aec9343 100644 --- a/cosmos_framework/utils/config.py +++ b/cosmos_framework/utils/config.py @@ -8,7 +8,7 @@ import importlib import os import time -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union import attrs import torch @@ -517,7 +517,18 @@ def validate(self) -> None: assert self.job.name != "" -def load_config(config_path: str, opts: list[str], enable_one_logger: bool = False) -> Config: +def load_config( + config_path: str, + opts: list[str], + enable_one_logger: bool = False, + pre_override: Optional[Callable[[Config], None]] = None, +) -> Config: + """Load a config from a ``.yaml`` or ``.py`` path and apply ``opts``. + + ``pre_override`` is an optional hook called on the freshly-built config + (after ``make_config()``, before ``override()``) — use it to mutate the + config so the change is part of the OmegaConf tree Hydra resolves. + """ from cosmos_framework.utils.serialization import from_yaml, load_callable t1 = time.monotonic_ns() @@ -528,9 +539,11 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal from cosmos_framework.utils.config_helper import override + if pre_override is not None: + pre_override(config) config = override(config, opts, remove_defaults=True) else: - config = _load_py_config(config_path, opts, validate=False) + config = _load_py_config(config_path, opts, validate=False, pre_override=pre_override) if enable_one_logger: try: @@ -549,7 +562,12 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal return config -def _load_py_config(config_path: str, opts: list[str], validate: bool = True) -> Config: +def _load_py_config( + config_path: str, + opts: list[str], + validate: bool = True, + pre_override: Optional[Callable[[Config], None]] = None, +) -> Config: # NOTE: circular dependency from cosmos_framework.utils.config_helper import get_config_module, override @@ -563,6 +581,9 @@ def _load_py_config(config_path: str, opts: list[str], validate: bool = True) -> t2 = time.monotonic_ns() logging.debug(f"importlib.import_module: took {(t2 - t1) / 1e6:.2f}ms") + if pre_override is not None: + pre_override(config) + t1 = time.monotonic_ns() config = override(config, opts) t2 = time.monotonic_ns() diff --git a/docs/sft_config.md b/docs/sft_config.md index 6b79dab..41b6b60 100644 --- a/docs/sft_config.md +++ b/docs/sft_config.md @@ -322,7 +322,7 @@ A few useful knobs aren't currently modeled by `SFTExperimentConfig` because the 3. Picks the base config from `[job].task`: `TASK_TO_BASE_CONFIG["vfm"|"vlm"]`. 4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 6, not per-leaf-remapped). 5. Appends `extra_overrides` (CLI tail) so they take precedence over the TOML. -6. Imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`), sets the TOML's `[custom]` table onto `config.custom` (so it is part of the single OmegaConf tree Hydra resolves — that's what lets `${custom}` interpolations resolve), then runs `override(config, overrides)` — Hydra `compose` resolves the `experiment=` selector against `ConfigStore` and applies the dotted-path overrides. +6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides, pre_override=…)`, which imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`). The `pre_override` hook sets the TOML's `[custom]` table onto `config.custom` before `override()` runs, so `[custom]` is part of the single OmegaConf tree Hydra resolves — that's what lets `${custom}` interpolations resolve. `override(config, overrides)` then has Hydra `compose` resolve the `experiment=` selector against `ConfigStore` and apply the dotted-path overrides. The returned `Config` is ready for `launch()`. From ae07e2acb30538b6e98871ce0930e316f773b51e Mon Sep 17 00:00:00 2001 From: Alex Peng Date: Thu, 25 Jun 2026 16:01:07 +0800 Subject: [PATCH 3/6] test: relocate config test to be auto-discovered by CI Signed-off-by: Alex Peng --- .../configs/toml_config/sft_config_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/toml_custom_section_test.py => cosmos_framework/configs/toml_config/sft_config_test.py (99%) diff --git a/tests/toml_custom_section_test.py b/cosmos_framework/configs/toml_config/sft_config_test.py similarity index 99% rename from tests/toml_custom_section_test.py rename to cosmos_framework/configs/toml_config/sft_config_test.py index 77a23ef..1b7ee62 100644 --- a/tests/toml_custom_section_test.py +++ b/cosmos_framework/configs/toml_config/sft_config_test.py @@ -160,7 +160,7 @@ def test_interpolation_resolves_and_attribute_access(self) -> None: experiment = "vision_sft_nano" project = "cosmos3" group = "sft" -name = "toml_custom_section_test" +name = "sft_config_custom_test" wandb_mode = "disabled" [model.tokenizer] From dfdf0989c097fe08ae0b8792959e99b2c599f9a5 Mon Sep 17 00:00:00 2001 From: "Alex Peng (Content Tech)" Date: Fri, 26 Jun 2026 13:49:37 +0800 Subject: [PATCH 4/6] feat(sft): inject TOML [custom] section onto config.custom post-load --- cosmos_framework/configs/base/config.py | 3 - .../configs/base/vlm/defaults/config.py | 3 - .../configs/toml_config/sft_config.py | 35 +++++------ .../configs/toml_config/sft_config_test.py | 63 ++----------------- cosmos_framework/utils/config.py | 18 +----- 5 files changed, 25 insertions(+), 97 deletions(-) diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py index 5293c23..e766c5c 100644 --- a/cosmos_framework/configs/base/config.py +++ b/cosmos_framework/configs/base/config.py @@ -25,9 +25,6 @@ class DataSetting: @attrs.define(slots=False) class Config(config.Config): data_setting: DataSetting = attrs.field(factory=DataSetting) - # Free-form, project-owned escape hatch fed by the SFT TOML's [custom] section. - # Default-empty so ${custom} interpolation and config.custom always resolve. - custom: dict = attrs.field(factory=dict) defaults: List[Any] = attrs.field( factory=lambda: [ "_self_", diff --git a/cosmos_framework/configs/base/vlm/defaults/config.py b/cosmos_framework/configs/base/vlm/defaults/config.py index 6145903..2128072 100644 --- a/cosmos_framework/configs/base/vlm/defaults/config.py +++ b/cosmos_framework/configs/base/vlm/defaults/config.py @@ -52,9 +52,6 @@ class DataSetting: class Config(config.Config): policy: PolicyConfig = PolicyConfig() data_setting: DataSetting = DataSetting() - # Free-form, project-owned escape hatch fed by the SFT TOML's [custom] section. - # Default-empty so ${custom} interpolation and config.custom always resolve. - custom: dict = attrs.field(factory=dict) defaults: List[Any] = attrs.field( factory=lambda: [ "_self_", diff --git a/cosmos_framework/configs/toml_config/sft_config.py b/cosmos_framework/configs/toml_config/sft_config.py index e35461a..5d17f2b 100644 --- a/cosmos_framework/configs/toml_config/sft_config.py +++ b/cosmos_framework/configs/toml_config/sft_config.py @@ -674,7 +674,8 @@ class SFTExperimentConfig(BaseModel): description=( "Free-form, project-owned escape hatch. Arbitrary nested content " "passes through verbatim — the framework never validates inside it. " - "Reachable as config.custom and via '${custom}' interpolation." + "Injected onto the loaded config as ``config.custom`` after Hydra " + "resolution; specify concrete values here (no ${...} interpolation)." ), ) @@ -703,14 +704,15 @@ def load_experiment_from_toml( The load then: - 1. Imports the base config module and runs ``make_config()`` (registers - config groups + experiment modules). - 2. Sets the TOML's ``[custom]`` table (if any) onto ``config.custom`` so it - is part of the OmegaConf tree Hydra resolves — kept out of - ``build_hydra_overrides`` so it lands verbatim, not per-leaf-remapped. - 3. Runs ``override(config, overrides)`` — Hydra ``compose`` resolves the - ``experiment=`` selector and applies the dotted-path overrides, - followed by ``extra_overrides``. + 1. Runs ``load_config`` — imports the base config module, runs + ``make_config()`` (registers config groups + experiment modules), and + lets Hydra ``compose`` resolve the ``experiment=`` selector and + apply the dotted-path overrides, followed by ``extra_overrides``. + 2. Injects the TOML's ``[custom]`` table (if any) verbatim onto + ``config.custom`` *after* loading — kept out of ``build_hydra_overrides`` + so it lands as-is, not per-leaf-remapped. Because this happens after + Hydra resolution, ``[custom]`` must hold concrete values; ``${...}`` + interpolation against ``custom`` is not supported. Returns the merged ``Config`` instance, ready for ``launch()``. """ @@ -748,13 +750,10 @@ def load_experiment_from_toml( # Import lazily so this module stays cheap to import in non-training contexts. from cosmos_framework.utils.config import load_config - # Inject [custom] before override() (via load_config's pre_override hook) so - # it is part of the OmegaConf tree Hydra resolves (enables ${custom} - # interpolation) and lands verbatim rather than being per-leaf-remapped. - custom = raw.get("custom") + config = load_config(base_config_path, overrides) - def _inject_custom(config: Any) -> None: - if custom: - config.custom = custom - - return load_config(base_config_path, overrides, pre_override=_inject_custom) + # Inject [custom] verbatim after Hydra resolution. Kept off the base config + # schema so the framework-owned hydra configs stay untouched; lands as a + # plain dict reachable via config.custom. + config.custom = raw.get("custom", {}) + return config diff --git a/cosmos_framework/configs/toml_config/sft_config_test.py b/cosmos_framework/configs/toml_config/sft_config_test.py index 1b7ee62..3385b41 100644 --- a/cosmos_framework/configs/toml_config/sft_config_test.py +++ b/cosmos_framework/configs/toml_config/sft_config_test.py @@ -7,7 +7,6 @@ from pathlib import Path -import attrs import pytest from pydantic import ValidationError @@ -104,55 +103,7 @@ def test_other_keys_still_emitted(self) -> None: # --------------------------------------------------------------------------- # -# 3. interpolation (${custom}) + attribute access via the real override path # -# --------------------------------------------------------------------------- # -@attrs.define(slots=False) -class _ProbeConfig: - """Minimal attrs config (no ``defaults`` field) so ``override`` resolves it - without needing registered config groups / the training stack.""" - - custom: dict = attrs.field(factory=dict) - dataloader_train: object = None - - -class TestCustomInterpolationAndAccess: - def test_interpolation_resolves_and_attribute_access(self) -> None: - from omegaconf import OmegaConf - - from cosmos_framework.utils.config_helper import override - from cosmos_framework.utils.lazy_config import instantiate - - # builtins.dict is a LazyCall-shaped, instantiate-locatable target. - probe = _ProbeConfig( - custom=dict(_CUSTOM_PAYLOAD), - dataloader_train={ - "_target_": "builtins.dict", - "whole": "${custom}", - "leaf": "${custom.scalar_int}", - "deep": "${custom.sampling.nested.deep}", - }, - ) - - resolved = override(probe, ["--"]) - - # attribute access + to_container round-trips to the plain dict. - assert OmegaConf.to_container(resolved.custom, resolve=True) == _CUSTOM_PAYLOAD - - # ${custom...} interpolations resolved against the top-level node. - dl = resolved.dataloader_train - assert OmegaConf.to_container(dl["whole"], resolve=True) == _CUSTOM_PAYLOAD - assert dl["leaf"] == _CUSTOM_PAYLOAD["scalar_int"] - assert dl["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"] - - # instantiate() yields the resolved content. - obj = instantiate(resolved.dataloader_train) - assert obj["leaf"] == _CUSTOM_PAYLOAD["scalar_int"] - assert obj["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"] - assert OmegaConf.to_container(obj["whole"], resolve=True) == _CUSTOM_PAYLOAD - - -# --------------------------------------------------------------------------- # -# 4. end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe # +# 3. end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe # # --------------------------------------------------------------------------- # _BASE_TOML = """\ [job] @@ -214,8 +165,6 @@ def _dummy_recipe_env(monkeypatch: pytest.MonkeyPatch) -> None: class TestEndToEndLoader: def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) -> None: - from omegaconf import OmegaConf - toml_path = tmp_path / "with_custom.toml" toml_path.write_text(_BASE_TOML + _CUSTOM_TOML_BLOCK) @@ -232,16 +181,14 @@ def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) {"path": "/data/b", "weight": 2.0}, ], } - # Reachable by attribute access and convertible to a plain dict, so a - # project can run MyProjectConfig.model_validate(). - assert OmegaConf.to_container(config.custom, resolve=True) == expected + # Injected verbatim as a plain dict after Hydra resolution, so a project + # can run MyProjectConfig.model_validate(config.custom) directly. + assert config.custom == expected def test_load_without_custom_section_defaults_empty(self, tmp_path: Path, _dummy_recipe_env: None) -> None: - from omegaconf import OmegaConf - toml_path = tmp_path / "no_custom.toml" toml_path.write_text(_BASE_TOML) config = _load_or_skip(toml_path) - assert OmegaConf.to_container(config.custom, resolve=True) == {} + assert config.custom == {} diff --git a/cosmos_framework/utils/config.py b/cosmos_framework/utils/config.py index aec9343..5078fbc 100644 --- a/cosmos_framework/utils/config.py +++ b/cosmos_framework/utils/config.py @@ -8,7 +8,7 @@ import importlib import os import time -from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import attrs import torch @@ -521,14 +521,8 @@ def load_config( config_path: str, opts: list[str], enable_one_logger: bool = False, - pre_override: Optional[Callable[[Config], None]] = None, ) -> Config: - """Load a config from a ``.yaml`` or ``.py`` path and apply ``opts``. - - ``pre_override`` is an optional hook called on the freshly-built config - (after ``make_config()``, before ``override()``) — use it to mutate the - config so the change is part of the OmegaConf tree Hydra resolves. - """ + """Load a config from a ``.yaml`` or ``.py`` path and apply ``opts``.""" from cosmos_framework.utils.serialization import from_yaml, load_callable t1 = time.monotonic_ns() @@ -539,11 +533,9 @@ def load_config( from cosmos_framework.utils.config_helper import override - if pre_override is not None: - pre_override(config) config = override(config, opts, remove_defaults=True) else: - config = _load_py_config(config_path, opts, validate=False, pre_override=pre_override) + config = _load_py_config(config_path, opts, validate=False) if enable_one_logger: try: @@ -566,7 +558,6 @@ def _load_py_config( config_path: str, opts: list[str], validate: bool = True, - pre_override: Optional[Callable[[Config], None]] = None, ) -> Config: # NOTE: circular dependency from cosmos_framework.utils.config_helper import get_config_module, override @@ -581,9 +572,6 @@ def _load_py_config( t2 = time.monotonic_ns() logging.debug(f"importlib.import_module: took {(t2 - t1) / 1e6:.2f}ms") - if pre_override is not None: - pre_override(config) - t1 = time.monotonic_ns() config = override(config, opts) t2 = time.monotonic_ns() From b2ed29491bd9f6a3c2709c78ce6eba0a750f2c0c Mon Sep 17 00:00:00 2001 From: "Alex Peng (Content Tech)" Date: Fri, 26 Jun 2026 17:29:45 +0800 Subject: [PATCH 5/6] Refine docs for `[custom]` section in SFT config --- docs/sft_config.md | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/docs/sft_config.md b/docs/sft_config.md index 41b6b60..cedd1f6 100644 --- a/docs/sft_config.md +++ b/docs/sft_config.md @@ -229,33 +229,25 @@ Top-level dataloader scalars only. The dataloader's class (LazyCall) and full pi ## `[custom]` (free-form escape hatch) -`[custom]` is the **one** section the framework does not model, validate, or remap. It exists so a project built on cosmos-framework can carry its own config (dataset paths, sampling ratios, pairing constraints, …) in the **same** TOML as the framework training knobs — instead of a second sidecar file. +`[custom]` lets a project carry its own config (dataset paths, sampling ratios, …) in the **same** TOML as the framework knobs. The framework never looks inside it — it's the one section exempt from the `extra="forbid"` typo guard (every other section still rejects unknown keys). -Rules: +How it works: -- **Arbitrary nested content** is allowed and passes through verbatim: scalars, sub-tables (`[custom.a.b]`), and arrays-of-tables (`[[custom.items]]`). The framework never validates *inside* `[custom]` (the schema field is a plain `dict[str, Any]`). -- It is the **only** top-level key exempt from the `extra="forbid"` typo guard's *contents*. Every other section — and any unknown top-level key that isn't `custom` — still raises `ValidationError`. -- It is **not** routed through `PATH_REMAPS`. Instead of per-leaf Hydra overrides, the whole `[custom]` table is attached verbatim onto a top-level `custom` node on the resolved `Config`. -- The resolved node is reachable two ways: - - **OmegaConf/Hydra interpolation** from an experiment recipe — e.g. a LazyCall arg `config="${custom}"` or `"${custom.some_key}"` resolves to the custom content; and - - **attribute access** — `config.custom`. -- It converts back to a plain dict via `OmegaConf.to_container(config.custom, resolve=True)`, so a project can run `MyProjectConfig.model_validate()`. - -When `[custom]` is absent, `config.custom` is a default-empty node (`{}`) — `${custom}` still resolves and existing TOMLs are unaffected. +- **Arbitrary nested content** passes through verbatim — scalars, sub-tables (`[custom.a.b]`), arrays-of-tables (`[[custom.items]]`). +- It does **not** go through Hydra. After `load_config` finishes, the table is attached as a plain `dict` via `config.custom = raw.get("custom", {})` (or `{}` when absent — reading `config.custom` is always safe). +- So values must be **concrete**: `${custom}` interpolation is **not** supported, and `config.custom` is **not** part of `config.to_dict()` / serialized config dumps. ```toml [custom] your_custom_files = "custom_value" ``` -A project's experiment recipe then wires its data pipeline straight from `[custom]`, e.g.: +Read it directly to wire your own pipeline: ```python -config = L(TrainingDatasetConfig.model_validate)("${custom}") +project_cfg = TrainingDatasetConfig.model_validate(config.custom) ``` -so a single TOML drives both the framework training and the project data pipeline. - ## Cross-cutting behaviors ### `"???"` (MISSING) sentinel @@ -320,9 +312,10 @@ A few useful knobs aren't currently modeled by `SFTExperimentConfig` because the 1. Reads the TOML with `tomllib`. 2. Validates the parsed dict against `SFTExperimentConfig` (raises `ValidationError` on unknown keys). 3. Picks the base config from `[job].task`: `TASK_TO_BASE_CONFIG["vfm"|"vlm"]`. -4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 6, not per-leaf-remapped). +4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 7, not per-leaf-remapped). 5. Appends `extra_overrides` (CLI tail) so they take precedence over the TOML. -6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides, pre_override=…)`, which imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`). The `pre_override` hook sets the TOML's `[custom]` table onto `config.custom` before `override()` runs, so `[custom]` is part of the single OmegaConf tree Hydra resolves — that's what lets `${custom}` interpolations resolve. `override(config, overrides)` then has Hydra `compose` resolve the `experiment=` selector against `ConfigStore` and apply the dotted-path overrides. +6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`), then `override(config, overrides)` has Hydra `compose` resolve the `experiment=` selector against `ConfigStore` and apply the dotted-path overrides. +7. Injects `[custom]` after loading: `config.custom = raw.get("custom", {})`. This runs **after** Hydra resolution, so it leaves the framework-owned base config and `load_config` untouched and lands as a plain `dict` (no `${custom}` interpolation; not part of serialized config dumps). The returned `Config` is ready for `launch()`. From c7d0c3837cd1d4d42ff2db2289bace01c5e59a16 Mon Sep 17 00:00:00 2001 From: "Alex Peng (Content Tech)" Date: Fri, 26 Jun 2026 17:54:16 +0800 Subject: [PATCH 6/6] minor fix for docs --- docs/sft_config.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sft_config.md b/docs/sft_config.md index cedd1f6..8df6e91 100644 --- a/docs/sft_config.md +++ b/docs/sft_config.md @@ -315,7 +315,7 @@ A few useful knobs aren't currently modeled by `SFTExperimentConfig` because the 4. Calls `build_hydra_overrides(raw)` to produce a `["--", "experiment=", "k.p=v", …]` list with per-task remaps applied and MISSING values filtered. `[custom]` is skipped here (it is injected verbatim in step 7, not per-leaf-remapped). 5. Appends `extra_overrides` (CLI tail) so they take precedence over the TOML. 6. Calls `cosmos_framework.utils.config.load_config(base_config_path, overrides)`, which imports the base config module and runs `make_config()` (registers every config group and imports every experiment SKU's `cs.store(group="experiment", …)`), then `override(config, overrides)` has Hydra `compose` resolve the `experiment=` selector against `ConfigStore` and apply the dotted-path overrides. -7. Injects `[custom]` after loading: `config.custom = raw.get("custom", {})`. This runs **after** Hydra resolution, so it leaves the framework-owned base config and `load_config` untouched and lands as a plain `dict` (no `${custom}` interpolation; not part of serialized config dumps). +7. Injects `[custom]` after loading: `config.custom = raw.get("custom", {})`. This runs **after** Hydra resolution, so it lands as a plain `dict` (no `${custom}` interpolation; not part of serialized config dumps). The returned `Config` is ready for `launch()`.