Skip to content
31 changes: 27 additions & 4 deletions src/megatron/bridge/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CONFIG_FILE = "run_config.yaml"

logger = logging.getLogger(__name__)
_RUNTIME_ONLY_TARGETS = frozenset({"megatron.core.timers.Timers"})


def file_exists(path: str) -> bool:
Expand Down Expand Up @@ -276,7 +277,7 @@ def read_run_config(run_config_filename: str) -> dict[str, Any]:
else:
with open(run_config_filename, "r") as f:
config_dict = yaml.safe_load(f)
config_obj[0] = config_dict
config_obj[0] = _sanitize_run_config_object(config_dict)
except Exception as e:
error_msg = f"ERROR: Unable to load config file {run_config_filename}: {e}"
sys.stderr.write(error_msg + "\n")
Expand All @@ -288,19 +289,21 @@ def read_run_config(run_config_filename: str) -> dict[str, Any]:
if isinstance(config_obj[0], dict) and config_obj[0].get("error", False):
raise RuntimeError(config_obj[0]["msg"])

return config_obj[0]
return _sanitize_run_config_object(config_obj[0])
else:
try:
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
with msc.open(run_config_filename, "r") as f:
return yaml.safe_load(f)
config_dict = yaml.safe_load(f)
else:
with open(run_config_filename, "r") as f:
return yaml.safe_load(f)
config_dict = yaml.safe_load(f)
except Exception as e:
raise RuntimeError(f"Unable to load config file {run_config_filename}: {e}") from e

return _sanitize_run_config_object(config_dict)


@lru_cache()
def read_train_state(train_state_filename: str) -> TrainState:
Expand Down Expand Up @@ -351,3 +354,23 @@ def read_train_state(train_state_filename: str) -> TrainState:
return ts
except Exception as e:
raise RuntimeError(f"Unable to load train state file {train_state_filename}: {e}") from e


def _sanitize_run_config_object(obj: Any) -> Any:
"""Remove runtime-only objects from run config dictionaries.

Timers and other runtime constructs are serialized with `_target_` entries
that cannot be recreated without additional context (e.g., constructor
arguments provided at runtime). These objects are not required when loading
a checkpoint configuration, so we replace them with ``None`` to avoid
instantiation errors when the config is processed later.
"""

if isinstance(obj, dict):
target = obj.get("_target_")
if isinstance(target, str) and target in _RUNTIME_ONLY_TARGETS:
return None
return {key: _sanitize_run_config_object(value) for key, value in obj.items()}
if isinstance(obj, list):
return [_sanitize_run_config_object(item) for item in obj]
return obj
70 changes: 57 additions & 13 deletions src/megatron/bridge/utils/instantiate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import copy
import functools
import inspect
import logging
from enum import Enum
from textwrap import dedent
Expand Down Expand Up @@ -236,22 +237,12 @@ def instantiate_node(
if OmegaConf.is_missing(node, key) and is_partial:
continue
value = node[key]
try:
value = instantiate_node(value, mode=mode)
except (ImportError, InstantiationException) as e:
if mode == InstantiationMode.STRICT:
raise InstantiationException(
f"Error instantiating {value} for key {full_key}.{key}: {e}"
) from e
else:
value = None
logging.warning(
f"Error instantiating {value} for key {full_key}.{key}. "
f"Using None instead in lenient mode."
)
value = instantiate_node(value, mode=mode)
kwargs[key] = _convert_node(value)

assert callable(_target_)
# Drop unexpected kwargs in lenient mode or raise in strict mode
kwargs = _filter_kwargs_for_target(_target_, kwargs, full_key, mode)
return _call_target(_target_, partial, args, kwargs, full_key)
else:
dict_items = {}
Expand Down Expand Up @@ -356,6 +347,59 @@ def _convert_target_to_string(t: Any) -> Any:
return t


def _filter_kwargs_for_target(
target: Callable[..., Any] | type,
kwargs: dict[str, Any],
full_key: str,
mode: InstantiationMode,
) -> dict[str, Any]:
"""Drop unexpected keyword arguments for a target and warn.

If the target accepts ``**kwargs`` we forward everything. Otherwise we
inspect the signature and remove keys not present as keyword-capable
parameters, emitting a warning with the dropped keys.
"""
try:
signature = inspect.signature(target)
except (TypeError, ValueError):
# Some builtins or C-extensions may not have an inspectable signature.
return kwargs

parameters = signature.parameters
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
return kwargs

allowed_keys = {
name
for name, param in parameters.items()
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
}

unexpected = set(kwargs.keys()) - allowed_keys
if _Keys.ARGS in unexpected:
unexpected.remove(_Keys.ARGS)

if not unexpected:
return kwargs

target_str = _convert_target_to_string(target)
if mode == InstantiationMode.LENIENT:
# Warn and drop the unexpected keys
warning_msg = f"Dropping unexpected config keys for target '{target_str}': {sorted(unexpected)}"
if full_key:
warning_msg += f"\nfull_key: {full_key}"
logging.warning(warning_msg)
filtered = {k: v for k, v in kwargs.items() if k in allowed_keys}
if _Keys.ARGS in kwargs:
filtered[_Keys.ARGS] = kwargs[_Keys.ARGS]
return filtered
else:
msg = f"Unexpected config keys for target '{target_str}': {sorted(unexpected)}"
if full_key:
msg += f"\nfull_key: {full_key}"
raise InstantiationException(msg)


def _prepare_input_dict_or_list(d: Union[dict[Any, Any], list[Any]]) -> Any:
res: Any
if isinstance(d, dict):
Expand Down
25 changes: 25 additions & 0 deletions tests/unit_tests/training/utils/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,31 @@ def test_read_run_config_invalid_yaml(self, mock_is_initialized, mock_get_rank):
with pytest.raises(RuntimeError, match="Unable to load config file"):
read_run_config("invalid.yaml")

@patch("megatron.bridge.training.utils.checkpoint_utils.get_rank_safe", return_value=0)
@patch("megatron.bridge.training.utils.checkpoint_utils.torch.distributed.is_initialized", return_value=False)
def test_read_run_config_sanitizes_runtime_only_targets(self, mock_is_initialized, mock_get_rank):
"""Run config should drop runtime-only objects such as timers."""
raw_config = {
"model": {
"timers": {"_target_": "megatron.core.timers.Timers"},
"keep": {"_target_": "some.other.Component", "value": 1},
"nested": [
{"timers": {"_target_": "megatron.core.timers.Timers"}},
{"other": {"_target_": "another.Component", "value": 2}},
],
},
"tokenizer": {"type": "sentencepiece"},
}
config_yaml = yaml.dump(raw_config)

with patch("builtins.open", mock_open(read_data=config_yaml)):
result = read_run_config("config_with_timers.yaml")

assert result["model"]["timers"] is None
assert result["model"]["nested"][0]["timers"] is None
assert result["model"]["keep"]["_target_"] == "some.other.Component"
assert result["model"]["nested"][1]["other"]["_target_"] == "another.Component"

@patch("megatron.bridge.training.utils.checkpoint_utils.get_rank_safe")
@patch("megatron.bridge.training.utils.checkpoint_utils.torch.distributed.is_initialized")
@patch("megatron.bridge.training.utils.checkpoint_utils.torch.load")
Expand Down
106 changes: 98 additions & 8 deletions tests/unit_tests/utils/test_instantiate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import functools
import logging
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -191,18 +192,14 @@ def test_instantiate_strict_mode_error(self):
with pytest.raises(InstantiationException):
instantiate(config, mode=InstantiationMode.STRICT)

def test_instantiate_lenient_mode_error(self, caplog):
"""Test instantiate in lenient mode with error."""
def test_instantiate_lenient_mode_error(self):
"""In lenient mode, nested resolution errors now propagate (no auto-None)."""
config = {
"_target_": "tests.unit_tests.utils.test_instantiate_utils.TestClass",
"nested": {"_target_": "non.existent.module.Class"},
}
with caplog.at_level(logging.WARNING):
result = instantiate(config, mode=InstantiationMode.LENIENT)

assert isinstance(result, TestClass)
assert result.kwargs["nested"] is None
assert "Error instantiating" in caplog.text
with pytest.raises(InstantiationException, match="Error locating target"):
instantiate(config, mode=InstantiationMode.LENIENT)

def test_instantiate_with_omegaconf_dict(self):
"""Test instantiate with OmegaConf DictConfig."""
Expand Down Expand Up @@ -525,3 +522,96 @@ def test_missing_values_with_partial(self):
actual_result = result(arg2="value2")
expected = {"arg1": "value1", "arg2": "value2", "kwargs": {}}
assert actual_result == expected


class DummyTarget:
def __init__(self, a: int, b: int = 0) -> None:
self.a = a
self.b = b


class KwTarget:
def __init__(self, **kwargs) -> None:
self.kwargs = dict(kwargs)


def _target_qualname(obj) -> str:
return f"{obj.__module__}.{obj.__qualname__}"


def test_drops_unexpected_kwargs_and_warns(caplog: pytest.LogCaptureFixture) -> None:
config = {
"_target_": _target_qualname(DummyTarget),
"a": 10,
"foo": 123, # unexpected key that should be dropped
}

with caplog.at_level(logging.WARNING):
obj = instantiate(config)

assert isinstance(obj, DummyTarget)
assert obj.a == 10
# 'foo' is dropped; 'b' remains default
assert obj.b == 0

# Ensure a warning was emitted mentioning the dropped key
warnings = [rec.getMessage() for rec in caplog.records if rec.levelno == logging.WARNING]
assert any("Dropping unexpected config keys" in m for m in warnings)
assert any("foo" in m for m in warnings)


def test_allows_kwargs_when_target_accepts_var_kwargs(caplog: pytest.LogCaptureFixture) -> None:
config = {
"_target_": _target_qualname(KwTarget),
"foo": 1,
"bar": 2,
}

with caplog.at_level(logging.WARNING):
obj = instantiate(config)

assert isinstance(obj, KwTarget)
assert obj.kwargs == {"foo": 1, "bar": 2}

# No warning should be emitted for **kwargs targets
warnings = [rec.getMessage() for rec in caplog.records if rec.levelno == logging.WARNING]
assert not any("Dropping unexpected config keys" in m for m in warnings)


def test_raises_on_unexpected_kwargs_in_strict_mode() -> None:
config = {
"_target_": _target_qualname(DummyTarget),
"a": 10,
"foo": 123,
}

with pytest.raises(InstantiationException):
instantiate(config, mode=InstantiationMode.STRICT)


class TestEnum(enum.Enum):
A = 1
B = 2


class TestInstantiateEnum:
"""Test instantiation of Enums."""

def test_instantiate_enum_with_args(self):
"""Test instantiating an Enum with _args_."""
config = {
"_target_": "tests.unit_tests.utils.test_instantiate_utils.TestEnum",
"_args_": [1],
}
result = instantiate(config)
assert result == TestEnum.A

def test_instantiate_enum_with_args_lenient(self):
"""Test instantiating an Enum with _args_ in lenient mode (default)."""
config = {
"_target_": "tests.unit_tests.utils.test_instantiate_utils.TestEnum",
"_args_": [2],
}
# This previously failed because _args_ was dropped in lenient mode
result = instantiate(config)
assert result == TestEnum.B