-
Notifications
You must be signed in to change notification settings - Fork 48
Add custom section support to SFT TOML configuration #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
foreverlms
merged 11 commits into
NVIDIA:main
from
main-voice:feat/support-custom-config-fields
Jun 26, 2026
Merged
Changes from 3 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
a19b143
Add custom section support to SFT TOML configuration
main-voice 0be6429
Refactor SFT config loading to support custom section injection
main-voice ae07e2a
test: relocate config test to be auto-discovered by CI
main-voice e0dfdc0
Merge branch 'main' into feat/support-custom-config-fields
foreverlms dfdf098
feat(sft): inject TOML [custom] section onto config.custom post-load
main-voice cf1520f
Merge branch 'feat/support-custom-config-fields' of github.com:main-v…
main-voice db397bf
Merge branch 'main' into feat/support-custom-config-fields
lfengad b2ed294
Refine docs for `[custom]` section in SFT config
main-voice 2638608
Merge branch 'feat/support-custom-config-fields' of github.com:main-v…
main-voice c7d0c38
minor fix for docs
main-voice 57e2237
Merge branch 'main' into feat/support-custom-config-fields
main-voice File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
247 changes: 247 additions & 0 deletions
247
cosmos_framework/configs/toml_config/sft_config_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,247 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: OpenMDW-1.1 | ||
|
|
||
| """Loader tests for the free-form ``[custom]`` escape-hatch section of the SFT TOML.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| import attrs | ||
| import pytest | ||
| from pydantic import ValidationError | ||
|
|
||
| from cosmos_framework.configs.toml_config.sft_config import SFTExperimentConfig | ||
| from cosmos_framework.configs.toml_config.toml_config_helper import build_hydra_overrides | ||
|
|
||
# Representative ``[custom]`` payload shared by the tests in this module:
# scalars (int, str, bool, float), a nested sub-table, and an array-of-tables.
# ``_CUSTOM_TOML_BLOCK`` further down spells out the same values as TOML text.
_CUSTOM_PAYLOAD = {
    "scalar_int": 5,
    "scalar_str": "hello",
    "flag": True,
    "ratio": 0.3,
    "sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}},
    "items": [
        {"path": "/data/a", "weight": 1.0},
        {"path": "/data/b", "weight": 2.0},
    ],
}
|
|
||
|
|
||
| # --------------------------------------------------------------------------- # | ||
| # 1. pydantic schema validation # | ||
| # --------------------------------------------------------------------------- # | ||
class TestSchemaValidation:
    """Schema-level checks on ``SFTExperimentConfig.model_validate``.

    The tests assert that a ``custom`` section is accepted as-is while unknown
    keys anywhere else are still rejected with ``ValidationError``.
    """

    @staticmethod
    def _document(**sections: object) -> dict:
        """Return a raw config mapping: the ``job`` table used by every test here, then *sections*."""
        return {"job": {"task": "vfm", "experiment": "vision_sft_nano"}, **sections}

    def test_custom_section_validates_arbitrary_nested_content(self) -> None:
        """A nested ``[custom]`` payload validates and compares equal to what went in."""
        validated = SFTExperimentConfig.model_validate(self._document(custom=_CUSTOM_PAYLOAD))
        # Expected to be kept verbatim: no coercion and no validation of the inner content.
        assert validated.custom == _CUSTOM_PAYLOAD

    def test_no_custom_section_defaults_empty(self) -> None:
        """Without a ``custom`` key the validated model exposes an empty dict."""
        validated = SFTExperimentConfig.model_validate(self._document())
        assert validated.custom == {}

    def test_unknown_top_level_key_raises(self) -> None:
        """An unknown top-level section other than ``custom`` is rejected."""
        invalid = self._document(bogus_section={"x": 1})
        with pytest.raises(ValidationError):
            SFTExperimentConfig.model_validate(invalid)

    def test_unknown_key_inside_optimizer_raises(self) -> None:
        """A misspelled key inside a known section remains a hard error (extra='forbid')."""
        invalid = self._document(optimizer={"lr": 1.0e-4, "not_a_real_key": 1})
        with pytest.raises(ValidationError):
            SFTExperimentConfig.model_validate(invalid)

    def test_custom_does_not_loosen_sibling_validation(self) -> None:
        """Having ``[custom]`` present must not relax extra='forbid' for sibling sections."""
        invalid = self._document(
            custom=_CUSTOM_PAYLOAD,
            trainer={"max_iter": 10, "typo_here": True},
        )
        with pytest.raises(ValidationError):
            SFTExperimentConfig.model_validate(invalid)
|
|
||
|
|
||
| # --------------------------------------------------------------------------- # | ||
| # 2. build_hydra_overrides must NOT emit [custom] as per-leaf overrides # | ||
| # --------------------------------------------------------------------------- # | ||
class TestBuildHydraOverrides:
    """``build_hydra_overrides`` must leave ``[custom]`` out while still emitting other keys."""

    def test_custom_not_emitted_as_overrides(self) -> None:
        """No emitted override string may contain ``custom``."""
        emitted = build_hydra_overrides(
            {
                "job": {"task": "vfm", "experiment": "vision_sft_nano"},
                "optimizer": {"lr": 1.0e-5},
                "custom": _CUSTOM_PAYLOAD,
            }
        )
        # Plain substring match on purpose: it catches the section whether it
        # would be emitted verbatim or under a remapped key.
        leaked = [entry for entry in emitted if "custom" in entry]
        assert not leaked, emitted

    def test_other_keys_still_emitted(self) -> None:
        """The experiment selector and the optimizer learning rate are still present."""
        emitted = build_hydra_overrides(
            {
                "job": {"task": "vfm", "experiment": "vision_sft_nano"},
                "optimizer": {"lr": 1.0e-5},
                "custom": {"a": 1},
            }
        )
        assert "experiment=vision_sft_nano" in emitted
        lr_entries = [entry for entry in emitted if entry.startswith("optimizer.lr=")]
        assert lr_entries, emitted
|
|
||
|
|
||
| # --------------------------------------------------------------------------- # | ||
| # 3. interpolation (${custom}) + attribute access via the real override path # | ||
| # --------------------------------------------------------------------------- # | ||
@attrs.define(slots=False)
class _ProbeConfig:
    """Smallest attrs config the interpolation test can hand to ``override``.

    It deliberately has no ``defaults`` field, so ``override`` can resolve it
    without registered config groups or the training stack.
    """

    # Free-form payload; each instance gets its own fresh empty dict by default.
    custom: dict = attrs.Factory(dict)
    # Slot for the node under test; the interpolation test stores a
    # ``_target_`` mapping containing ``${custom...}`` references here.
    dataloader_train: object = attrs.field(default=None)
|
|
||
|
|
||
class TestCustomInterpolationAndAccess:
    """``${custom...}`` interpolation and attribute access through the real ``override`` path."""

    def test_interpolation_resolves_and_attribute_access(self) -> None:
        """Check ``custom`` after ``override``: attribute access, interpolation, ``instantiate``.

        A probe config carrying the shared payload and a ``_target_`` node
        full of ``${custom...}`` references is passed through ``override``.
        The resolved ``custom`` node, the interpolated entries, and the
        object built by ``instantiate`` must all match ``_CUSTOM_PAYLOAD``.
        """
        import copy

        from omegaconf import OmegaConf

        from cosmos_framework.utils.config_helper import override
        from cosmos_framework.utils.lazy_config import instantiate

        # Deep copy rather than ``dict(...)``: a shallow copy would still share
        # the nested ``sampling`` dict and ``items`` list with the module-level
        # constant. Any in-place change made while resolving would then alter
        # the very value the assertions below compare against.
        # builtins.dict is a LazyCall-shaped, instantiate-locatable target.
        probe = _ProbeConfig(
            custom=copy.deepcopy(_CUSTOM_PAYLOAD),
            dataloader_train={
                "_target_": "builtins.dict",
                "whole": "${custom}",
                "leaf": "${custom.scalar_int}",
                "deep": "${custom.sampling.nested.deep}",
            },
        )

        resolved = override(probe, ["--"])

        # attribute access + to_container round-trips to the plain dict.
        assert OmegaConf.to_container(resolved.custom, resolve=True) == _CUSTOM_PAYLOAD

        # ${custom...} interpolations resolved against the top-level node.
        dl = resolved.dataloader_train
        assert OmegaConf.to_container(dl["whole"], resolve=True) == _CUSTOM_PAYLOAD
        assert dl["leaf"] == _CUSTOM_PAYLOAD["scalar_int"]
        assert dl["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"]

        # instantiate() yields the resolved content.
        obj = instantiate(resolved.dataloader_train)
        assert obj["leaf"] == _CUSTOM_PAYLOAD["scalar_int"]
        assert obj["deep"] == _CUSTOM_PAYLOAD["sampling"]["nested"]["deep"]
        assert OmegaConf.to_container(obj["whole"], resolve=True) == _CUSTOM_PAYLOAD
|
|
||
|
|
||
| # --------------------------------------------------------------------------- # | ||
| # 4. end-to-end load_experiment_from_toml on the shipped vision_sft_nano recipe # | ||
| # --------------------------------------------------------------------------- # | ||
# Recipe TOML without a [custom] section: job metadata plus two paths read from
# environment variables through ``${oc.env:...}`` interpolation. The backslash
# right after the opening quotes keeps the text from starting with a newline.
_BASE_TOML = """\
[job]
task = "vfm"
experiment = "vision_sft_nano"
project = "cosmos3"
group = "sft"
name = "sft_config_custom_test"
wandb_mode = "disabled"

[model.tokenizer]
vae_path = "${oc.env:WAN_VAE_PATH}"

[checkpoint]
load_path = "${oc.env:BASE_CHECKPOINT_PATH}"
"""
|
|
||
# TOML rendering of the [custom] section: the same scalars, nested tables, and
# array-of-tables as ``_CUSTOM_PAYLOAD``. The text starts with an empty line so
# it stays visually separated when concatenated directly after ``_BASE_TOML``.
_CUSTOM_TOML_BLOCK = """\

[custom]
scalar_int = 5
scalar_str = "hello"
flag = true
ratio = 0.3

[custom.sampling]
bug_ratio = 0.3

[custom.sampling.nested]
deep = 1

[[custom.items]]
path = "/data/a"
weight = 1.0

[[custom.items]]
path = "/data/b"
weight = 2.0
"""
|
|
||
|
|
||
def _load_or_skip(toml_path: Path):
    """Load *toml_path* with the real loader, skipping the test on ``ImportError``.

    Returns whatever ``load_experiment_from_toml`` returns. An ``ImportError``
    raised during that call is turned into ``pytest.skip`` so environments
    without the training stack do not report a failure.
    """
    from cosmos_framework.configs.toml_config.sft_config import load_experiment_from_toml

    try:
        loaded = load_experiment_from_toml(str(toml_path))
    except ImportError as exc:  # pragma: no cover — env-dependent
        pytest.skip(f"training stack not importable here: {exc!r}")
    else:
        return loaded
|
|
||
|
|
||
@pytest.fixture
def _dummy_recipe_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set the environment variables the recipe interpolates to placeholder paths."""
    # vision_sft_nano interpolates these env vars into path strings at resolve time.
    placeholder_paths = {
        "DATASET_PATH": "/tmp/dummy_dataset",
        "WAN_VAE_PATH": "/tmp/dummy_vae.pth",
        "BASE_CHECKPOINT_PATH": "/tmp/dummy_ckpt",
    }
    for variable, value in placeholder_paths.items():
        monkeypatch.setenv(variable, value)
|
|
||
|
|
||
class TestEndToEndLoader:
    """Exercise the real TOML loader on files written into ``tmp_path``."""

    @staticmethod
    def _load_custom(directory: Path, filename: str, toml_text: str):
        """Write *toml_text* to ``directory / filename``, load it, and return ``config.custom`` as plain containers."""
        from omegaconf import OmegaConf

        recipe_file = directory / filename
        recipe_file.write_text(toml_text)

        loaded = _load_or_skip(recipe_file)
        return OmegaConf.to_container(loaded.custom, resolve=True)

    def test_load_with_custom_section(self, tmp_path: Path, _dummy_recipe_env: None) -> None:
        """The ``[custom]`` table written to disk comes back as the equivalent plain dict."""
        actual = self._load_custom(tmp_path, "with_custom.toml", _BASE_TOML + _CUSTOM_TOML_BLOCK)

        # ``config.custom`` is reachable by attribute access and converts to a
        # plain dict, so a project can run MyProjectConfig.model_validate(<that dict>).
        assert actual == {
            "scalar_int": 5,
            "scalar_str": "hello",
            "flag": True,
            "ratio": 0.3,
            "sampling": {"bug_ratio": 0.3, "nested": {"deep": 1}},
            "items": [
                {"path": "/data/a", "weight": 1.0},
                {"path": "/data/b", "weight": 2.0},
            ],
        }

    def test_load_without_custom_section_defaults_empty(self, tmp_path: Path, _dummy_recipe_env: None) -> None:
        """A recipe with no ``[custom]`` table loads with an empty ``custom`` mapping."""
        actual = self._load_custom(tmp_path, "no_custom.toml", _BASE_TOML)

        assert actual == {}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.