diff --git a/src/megatron/bridge/training/utils/checkpoint_utils.py b/src/megatron/bridge/training/utils/checkpoint_utils.py index 871c3b353d..d1a2720cd4 100644 --- a/src/megatron/bridge/training/utils/checkpoint_utils.py +++ b/src/megatron/bridge/training/utils/checkpoint_utils.py @@ -33,6 +33,7 @@ CONFIG_FILE = "run_config.yaml" logger = logging.getLogger(__name__) +_RUNTIME_ONLY_TARGETS = frozenset({"megatron.core.timers.Timers"}) def file_exists(path: str) -> bool: @@ -276,7 +277,7 @@ def read_run_config(run_config_filename: str) -> dict[str, Any]: else: with open(run_config_filename, "r") as f: config_dict = yaml.safe_load(f) - config_obj[0] = config_dict + config_obj[0] = _sanitize_run_config_object(config_dict) except Exception as e: error_msg = f"ERROR: Unable to load config file {run_config_filename}: {e}" sys.stderr.write(error_msg + "\n") @@ -288,19 +289,21 @@ def read_run_config(run_config_filename: str) -> dict[str, Any]: if isinstance(config_obj[0], dict) and config_obj[0].get("error", False): raise RuntimeError(config_obj[0]["msg"]) - return config_obj[0] + return _sanitize_run_config_object(config_obj[0]) else: try: if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() with msc.open(run_config_filename, "r") as f: - return yaml.safe_load(f) + config_dict = yaml.safe_load(f) else: with open(run_config_filename, "r") as f: - return yaml.safe_load(f) + config_dict = yaml.safe_load(f) except Exception as e: raise RuntimeError(f"Unable to load config file {run_config_filename}: {e}") from e + return _sanitize_run_config_object(config_dict) + @lru_cache() def read_train_state(train_state_filename: str) -> TrainState: @@ -351,3 +354,23 @@ def read_train_state(train_state_filename: str) -> TrainState: return ts except Exception as e: raise RuntimeError(f"Unable to load train state file {train_state_filename}: {e}") from e + + +def _sanitize_run_config_object(obj: Any) -> Any: + """Remove runtime-only objects from run config dictionaries. + + Timers and other runtime constructs are serialized with `_target_` entries + that cannot be recreated without additional context (e.g., constructor + arguments provided at runtime). These objects are not required when loading + a checkpoint configuration, so we replace them with ``None`` to avoid + instantiation errors when the config is processed later. + """ + + if isinstance(obj, dict): + target = obj.get("_target_") + if isinstance(target, str) and target in _RUNTIME_ONLY_TARGETS: + return None + return {key: _sanitize_run_config_object(value) for key, value in obj.items()} + if isinstance(obj, list): + return [_sanitize_run_config_object(item) for item in obj] + return obj diff --git a/src/megatron/bridge/utils/instantiate_utils.py b/src/megatron/bridge/utils/instantiate_utils.py index 75efde252e..83d7290f2c 100644 --- a/src/megatron/bridge/utils/instantiate_utils.py +++ b/src/megatron/bridge/utils/instantiate_utils.py @@ -17,6 +17,7 @@ import copy import functools +import inspect import logging from enum import Enum from textwrap import dedent @@ -236,22 +237,12 @@ def instantiate_node( if OmegaConf.is_missing(node, key) and is_partial: continue value = node[key] - try: - value = instantiate_node(value, mode=mode) - except (ImportError, InstantiationException) as e: - if mode == InstantiationMode.STRICT: - raise InstantiationException( - f"Error instantiating {value} for key {full_key}.{key}: {e}" - ) from e - else: - value = None - logging.warning( - f"Error instantiating {value} for key {full_key}.{key}. " - f"Using None instead in lenient mode." - ) + value = instantiate_node(value, mode=mode) kwargs[key] = _convert_node(value) assert callable(_target_) + # Drop unexpected kwargs in lenient mode or raise in strict mode + kwargs = _filter_kwargs_for_target(_target_, kwargs, full_key, mode) return _call_target(_target_, partial, args, kwargs, full_key) else: dict_items = {} @@ -356,6 +347,59 @@ def _convert_target_to_string(t: Any) -> Any: return t +def _filter_kwargs_for_target( + target: Callable[..., Any] | type, + kwargs: dict[str, Any], + full_key: str, + mode: InstantiationMode, +) -> dict[str, Any]: + """Drop unexpected keyword arguments for a target and warn. + + If the target accepts ``**kwargs`` we forward everything. Otherwise we + inspect the signature and remove keys not present as keyword-capable + parameters, emitting a warning with the dropped keys. + """ + try: + signature = inspect.signature(target) + except (TypeError, ValueError): + # Some builtins or C-extensions may not have an inspectable signature. + return kwargs + + parameters = signature.parameters + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()): + return kwargs + + allowed_keys = { + name + for name, param in parameters.items() + if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + } + + unexpected = set(kwargs.keys()) - allowed_keys + if _Keys.ARGS in unexpected: + unexpected.remove(_Keys.ARGS) + + if not unexpected: + return kwargs + + target_str = _convert_target_to_string(target) + if mode == InstantiationMode.LENIENT: + # Warn and drop the unexpected keys + warning_msg = f"Dropping unexpected config keys for target '{target_str}': {sorted(unexpected)}" + if full_key: + warning_msg += f"\nfull_key: {full_key}" + logging.warning(warning_msg) + filtered = {k: v for k, v in kwargs.items() if k in allowed_keys} + if _Keys.ARGS in kwargs: + filtered[_Keys.ARGS] = kwargs[_Keys.ARGS] + return filtered + else: + msg = f"Unexpected config keys for target '{target_str}': {sorted(unexpected)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + + def _prepare_input_dict_or_list(d: Union[dict[Any, Any], list[Any]]) -> Any: res: Any if isinstance(d, dict): diff --git a/tests/unit_tests/training/utils/test_checkpoint_utils.py b/tests/unit_tests/training/utils/test_checkpoint_utils.py index c200e49280..3c24f30d37 100644 --- a/tests/unit_tests/training/utils/test_checkpoint_utils.py +++ b/tests/unit_tests/training/utils/test_checkpoint_utils.py @@ -249,6 +249,31 @@ def test_read_run_config_invalid_yaml(self, mock_is_initialized, mock_get_rank): with pytest.raises(RuntimeError, match="Unable to load config file"): read_run_config("invalid.yaml") + @patch("megatron.bridge.training.utils.checkpoint_utils.get_rank_safe", return_value=0) + @patch("megatron.bridge.training.utils.checkpoint_utils.torch.distributed.is_initialized", return_value=False) + def test_read_run_config_sanitizes_runtime_only_targets(self, mock_is_initialized, mock_get_rank): + """Run config should drop runtime-only objects such as timers.""" + raw_config = { + "model": { + "timers": {"_target_": "megatron.core.timers.Timers"}, + "keep": {"_target_": "some.other.Component", "value": 1}, + "nested": [ + {"timers": {"_target_": "megatron.core.timers.Timers"}}, + {"other": {"_target_": "another.Component", "value": 2}}, + ], + }, + "tokenizer": {"type": "sentencepiece"}, + } + config_yaml = yaml.dump(raw_config) + + with patch("builtins.open", mock_open(read_data=config_yaml)): + result = read_run_config("config_with_timers.yaml") + + assert result["model"]["timers"] is None + assert result["model"]["nested"][0]["timers"] is None + assert result["model"]["keep"]["_target_"] == "some.other.Component" + assert result["model"]["nested"][1]["other"]["_target_"] == "another.Component" + @patch("megatron.bridge.training.utils.checkpoint_utils.get_rank_safe") @patch("megatron.bridge.training.utils.checkpoint_utils.torch.distributed.is_initialized") @patch("megatron.bridge.training.utils.checkpoint_utils.torch.load") diff --git a/tests/unit_tests/utils/test_instantiate_utils.py b/tests/unit_tests/utils/test_instantiate_utils.py index d4ad616147..af2ad22ae7 100644 --- a/tests/unit_tests/utils/test_instantiate_utils.py +++ b/tests/unit_tests/utils/test_instantiate_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import enum import functools import logging from unittest.mock import MagicMock, patch @@ -191,18 +192,14 @@ def test_instantiate_strict_mode_error(self): with pytest.raises(InstantiationException): instantiate(config, mode=InstantiationMode.STRICT) - def test_instantiate_lenient_mode_error(self, caplog): - """Test instantiate in lenient mode with error.""" + def test_instantiate_lenient_mode_error(self): + """In lenient mode, nested resolution errors now propagate (no auto-None).""" config = { "_target_": "tests.unit_tests.utils.test_instantiate_utils.TestClass", "nested": {"_target_": "non.existent.module.Class"}, } - with caplog.at_level(logging.WARNING): - result = instantiate(config, mode=InstantiationMode.LENIENT) - - assert isinstance(result, TestClass) - assert result.kwargs["nested"] is None - assert "Error instantiating" in caplog.text + with pytest.raises(InstantiationException, match="Error locating target"): + instantiate(config, mode=InstantiationMode.LENIENT) def test_instantiate_with_omegaconf_dict(self): """Test instantiate with OmegaConf DictConfig.""" @@ -525,3 +522,96 @@ def test_missing_values_with_partial(self): actual_result = result(arg2="value2") expected = {"arg1": "value1", "arg2": "value2", "kwargs": {}} assert actual_result == expected + + +class DummyTarget: + def __init__(self, a: int, b: int = 0) -> None: + self.a = a + self.b = b + + +class KwTarget: + def __init__(self, **kwargs) -> None: + self.kwargs = dict(kwargs) + + +def _target_qualname(obj) -> str: + return f"{obj.__module__}.{obj.__qualname__}" + + +def test_drops_unexpected_kwargs_and_warns(caplog: pytest.LogCaptureFixture) -> None: + config = { + "_target_": _target_qualname(DummyTarget), + "a": 10, + "foo": 123, # unexpected key that should be dropped + } + + with caplog.at_level(logging.WARNING): + obj = instantiate(config) + + assert isinstance(obj, DummyTarget) + assert obj.a == 10 + # 'foo' is dropped; 'b' remains default + assert obj.b == 0 + + # Ensure a warning was emitted mentioning the dropped key + warnings = [rec.getMessage() for rec in caplog.records if rec.levelno == logging.WARNING] + assert any("Dropping unexpected config keys" in m for m in warnings) + assert any("foo" in m for m in warnings) + + +def test_allows_kwargs_when_target_accepts_var_kwargs(caplog: pytest.LogCaptureFixture) -> None: + config = { + "_target_": _target_qualname(KwTarget), + "foo": 1, + "bar": 2, + } + + with caplog.at_level(logging.WARNING): + obj = instantiate(config) + + assert isinstance(obj, KwTarget) + assert obj.kwargs == {"foo": 1, "bar": 2} + + # No warning should be emitted for **kwargs targets + warnings = [rec.getMessage() for rec in caplog.records if rec.levelno == logging.WARNING] + assert not any("Dropping unexpected config keys" in m for m in warnings) + + +def test_raises_on_unexpected_kwargs_in_strict_mode() -> None: + config = { + "_target_": _target_qualname(DummyTarget), + "a": 10, + "foo": 123, + } + + with pytest.raises(InstantiationException): + instantiate(config, mode=InstantiationMode.STRICT) + + +class TestEnum(enum.Enum): + A = 1 + B = 2 + + +class TestInstantiateEnum: + """Test instantiation of Enums.""" + + def test_instantiate_enum_with_args(self): + """Test instantiating an Enum with _args_.""" + config = { + "_target_": "tests.unit_tests.utils.test_instantiate_utils.TestEnum", + "_args_": [1], + } + result = instantiate(config) + assert result == TestEnum.A + + def test_instantiate_enum_with_args_lenient(self): + """Test instantiating an Enum with _args_ in lenient mode (default).""" + config = { + "_target_": "tests.unit_tests.utils.test_instantiate_utils.TestEnum", + "_args_": [2], + } + # This previously failed because _args_ was dropped in lenient mode + result = instantiate(config) + assert result == TestEnum.B