Skip to content
64 changes: 51 additions & 13 deletions src/megatron/bridge/utils/instantiate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import copy
import functools
import inspect
import logging
from enum import Enum
from textwrap import dedent
Expand Down Expand Up @@ -236,22 +237,12 @@ def instantiate_node(
if OmegaConf.is_missing(node, key) and is_partial:
continue
value = node[key]
try:
value = instantiate_node(value, mode=mode)
except (ImportError, InstantiationException) as e:
if mode == InstantiationMode.STRICT:
raise InstantiationException(
f"Error instantiating {value} for key {full_key}.{key}: {e}"
) from e
else:
value = None
logging.warning(
f"Error instantiating {value} for key {full_key}.{key}. "
f"Using None instead in lenient mode."
)
value = instantiate_node(value, mode=mode)
kwargs[key] = _convert_node(value)

assert callable(_target_)
# Drop unexpected kwargs in lenient mode or raise in strict mode
kwargs = _filter_kwargs_for_target(_target_, kwargs, full_key, mode)
return _call_target(_target_, partial, args, kwargs, full_key)
else:
dict_items = {}
Expand Down Expand Up @@ -356,6 +347,53 @@ def _convert_target_to_string(t: Any) -> Any:
return t


def _filter_kwargs_for_target(
    target: Callable[..., Any] | type,
    kwargs: dict[str, Any],
    full_key: str,
    mode: InstantiationMode,
) -> dict[str, Any]:
    """Reconcile config-derived keyword arguments with a target's signature.

    Keys that the target cannot receive by keyword are either dropped with a
    warning (lenient mode) or reported as an error (any other mode).

    Args:
        target: Callable or class whose signature is inspected.
        kwargs: Keyword arguments collected from the config node.
        full_key: Config path of the node; appended to the message when non-empty.
        mode: ``InstantiationMode.LENIENT`` drops unexpected keys, anything
            else raises.

    Returns:
        ``kwargs`` itself when nothing has to be removed (signature not
        inspectable, target takes ``**kwargs``, or every key is accepted);
        otherwise a new dict holding only the accepted keys.

    Raises:
        InstantiationException: If unexpected keys are present and ``mode`` is
            not lenient.
    """
    try:
        signature = inspect.signature(target)
    except (TypeError, ValueError):
        # Some builtins or C-extensions expose no inspectable signature;
        # there is nothing to validate against, so forward everything.
        return kwargs

    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    accepted = set()
    for name, param in signature.parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            # A ``**kwargs`` catch-all means any key is acceptable.
            return kwargs
        if param.kind in keyword_kinds:
            accepted.add(name)

    rejected = [key for key in kwargs if key not in accepted]
    if not rejected:
        return kwargs

    target_str = _convert_target_to_string(target)
    rejected_sorted = sorted(rejected)
    location = f"\nfull_key: {full_key}" if full_key else ""

    if mode != InstantiationMode.LENIENT:
        raise InstantiationException(
            f"Unexpected config keys for target '{target_str}': {rejected_sorted}{location}"
        )

    # NOTE (raised in review): dropping a key also discards whatever behaviour
    # it used to configure. A config saved by an older version (e.g. one that
    # set ``external_cuda_graph=True`` before that field was removed, see
    # NVIDIA/Megatron-LM#1917) will load, but the old setting is lost; it may
    # not be possible to infer the new setting from the old one, so no claim
    # of full reproducibility across versions should be made in this case.
    logging.warning(f"Dropping unexpected config keys for target '{target_str}': {rejected_sorted}{location}")
    return {key: value for key, value in kwargs.items() if key in accepted}


def _prepare_input_dict_or_list(d: Union[dict[Any, Any], list[Any]]) -> Any:
res: Any
if isinstance(d, dict):
Expand Down
77 changes: 69 additions & 8 deletions tests/unit_tests/utils/test_instantiate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,14 @@ def test_instantiate_strict_mode_error(self):
with pytest.raises(InstantiationException):
instantiate(config, mode=InstantiationMode.STRICT)

def test_instantiate_lenient_mode_error(self, caplog):
"""Test instantiate in lenient mode with error."""
def test_instantiate_lenient_mode_error(self):
"""In lenient mode, nested resolution errors now propagate (no auto-None)."""
config = {
"_target_": "tests.unit_tests.utils.test_instantiate_utils.TestClass",
"nested": {"_target_": "non.existent.module.Class"},
}
with caplog.at_level(logging.WARNING):
result = instantiate(config, mode=InstantiationMode.LENIENT)

assert isinstance(result, TestClass)
assert result.kwargs["nested"] is None
assert "Error instantiating" in caplog.text
with pytest.raises(InstantiationException, match="Error locating target"):
instantiate(config, mode=InstantiationMode.LENIENT)

def test_instantiate_with_omegaconf_dict(self):
"""Test instantiate with OmegaConf DictConfig."""
Expand Down Expand Up @@ -525,3 +521,68 @@ def test_missing_values_with_partial(self):
actual_result = result(arg2="value2")
expected = {"arg1": "value1", "arg2": "value2", "kwargs": {}}
assert actual_result == expected


class DummyTarget:
    """Test target with a fixed signature: required ``a`` and optional ``b``."""

    def __init__(self, a: int, b: int = 0) -> None:
        # Store both arguments unchanged so tests can read them back.
        self.a, self.b = a, b


class KwTarget:
    """Test target that accepts arbitrary keyword arguments."""

    def __init__(self, **kwargs) -> None:
        # Keep a plain-dict copy of everything that was passed in.
        self.kwargs = {**kwargs}


def _target_qualname(obj) -> str:
    """Return ``<module>.<qualname>`` for ``obj``, as used for ``_target_`` values in these tests."""
    module_name = obj.__module__
    qualified_name = obj.__qualname__
    return "{}.{}".format(module_name, qualified_name)


def test_drops_unexpected_kwargs_and_warns(caplog: pytest.LogCaptureFixture) -> None:
    """An unknown config key is dropped with a warning when no mode is passed."""
    # 'foo' is not a parameter of DummyTarget.__init__ and should be dropped.
    config = dict(_target_=_target_qualname(DummyTarget), a=10, foo=123)

    with caplog.at_level(logging.WARNING):
        obj = instantiate(config)

    assert isinstance(obj, DummyTarget)
    assert obj.a == 10
    # 'foo' never reaches the constructor, so 'b' keeps its default.
    assert obj.b == 0

    # A warning must mention both the drop and the offending key.
    warning_messages = []
    for record in caplog.records:
        if record.levelno == logging.WARNING:
            warning_messages.append(record.getMessage())
    assert any("Dropping unexpected config keys" in message for message in warning_messages)
    assert any("foo" in message for message in warning_messages)


def test_allows_kwargs_when_target_accepts_var_kwargs(caplog: pytest.LogCaptureFixture) -> None:
    """A target declaring ``**kwargs`` receives every config key and triggers no drop warning."""
    extra = {"foo": 1, "bar": 2}
    config = {"_target_": _target_qualname(KwTarget), **extra}

    with caplog.at_level(logging.WARNING):
        obj = instantiate(config)

    assert isinstance(obj, KwTarget)
    assert obj.kwargs == extra

    # No drop warning should be emitted for **kwargs targets.
    warning_messages = [record.getMessage() for record in caplog.records if record.levelno == logging.WARNING]
    assert all("Dropping unexpected config keys" not in message for message in warning_messages)


def test_raises_on_unexpected_kwargs_in_strict_mode() -> None:
    """In strict mode an unknown config key raises instead of being dropped."""
    # 'foo' is not a parameter of DummyTarget.__init__.
    config = dict(_target_=_target_qualname(DummyTarget), a=10, foo=123)

    with pytest.raises(InstantiationException):
        instantiate(config, mode=InstantiationMode.STRICT)
Loading