Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Adafactor optim on torch2.5 and fix compatibility #1600

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module(module=_optim)
if module_name == 'Adafactor':
OPTIMIZERS.register_module(
name='torch_Adafactor', module=_optim)
else:
OPTIMIZERS.register_module(module=_optim)
torch_optimizers.append(module_name)
return torch_optimizers

16 changes: 15 additions & 1 deletion mmengine/registry/build_functions.py
Original file line number Diff line number Diff line change
@@ -3,8 +3,10 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, Union

import torch

from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin
from mmengine.utils import ManagerMixin, digit_version
from .registry import Registry

if TYPE_CHECKING:
@@ -232,6 +234,18 @@ def build_model_from_cfg(
return build_from_cfg(cfg, registry, default_args)


def build_optimizer_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
if 'Adafactor' == cfg['type'] and digit_version(
torch.__version__) >= digit_version('2.5.0'):
from ..logging import print_log
print_log(
'the torch version of Adafactor is registered as torch_Adafactor')
return build_from_cfg(cfg, registry, default_args)


def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
6 changes: 3 additions & 3 deletions mmengine/registry/root.py
Original file line number Diff line number Diff line change
@@ -6,8 +6,8 @@
https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
"""

from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
build_scheduler_from_cfg)
from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
build_runner_from_cfg, build_scheduler_from_cfg)
from .registry import Registry

# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@@ -35,7 +35,7 @@
WEIGHT_INITIALIZERS = Registry('weight initializer')

# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer')
OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper')
# manage constructors that customize the optimization hyperparameters.
21 changes: 18 additions & 3 deletions mmengine/testing/_internal/distributed.py
Original file line number Diff line number Diff line change
@@ -92,10 +92,25 @@ def wrapper(self):
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
# or run the underlying test function.
def __init__(self, method_name: str = 'runTest') -> None:
def __init__(self,
method_name: str = 'runTest',
methodName: str = 'runTest') -> None:
# methodName is the correct naming in unittest
# and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != 'runTest':
method_name = methodName
super().__init__(method_name)
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
try:
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
except AttributeError as e:
if methodName != 'runTest':
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(
f'no such test method in {self.__class__}: {methodName}'
) from e

def setUp(self) -> None:
super().setUp()