Skip to content

Commit ad8aaea

Browse files
authored
Revert "Remove the module class scan (#2790)" (#2819)
This reverts commit b96dc32.
1 parent d7d97fd commit ad8aaea

File tree

2 files changed

+7
-72
lines changed

2 files changed

+7
-72
lines changed

nvflare/fuel/utils/class_utils.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Dict, List, Optional
2020

2121
from nvflare.security.logging import secure_format_exception
22-
from nvflare.utils.components_utils import create_classes_table_static
2322

2423
DEPRECATED_PACKAGES = ["nvflare.app_common.pt", "nvflare.app_common.homomorphic_encryption"]
2524

@@ -86,10 +85,10 @@ def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=T
8685
self.exclude_libs = exclude_libs
8786

8887
self._logger = logging.getLogger(self.__class__.__name__)
89-
self._class_table = create_classes_table_static()
88+
self._class_table: Dict[str, str] = {}
89+
self._create_classes_table()
9090

9191
def _create_classes_table(self):
92-
class_table: Dict[str, str] = {}
9392
scan_result_table = {}
9493
for base in self.base_pkgs:
9594
package = importlib.import_module(base)
@@ -112,21 +111,20 @@ def _create_classes_table(self):
112111
# same class name exists in multiple modules
113112
if name in scan_result_table:
114113
scan_result = scan_result_table[name]
115-
if name in class_table:
116-
class_table.pop(name)
117-
class_table[f"{scan_result.module_name}.{name}"] = module_name
118-
class_table[f"{module_name}.{name}"] = module_name
114+
if name in self._class_table:
115+
self._class_table.pop(name)
116+
self._class_table[f"{scan_result.module_name}.{name}"] = module_name
117+
self._class_table[f"{module_name}.{name}"] = module_name
119118
else:
120119
scan_result = _ModuleScanResult(class_name=name, module_name=module_name)
121120
scan_result_table[name] = scan_result
122-
class_table[name] = module_name
121+
self._class_table[name] = module_name
123122
except (ModuleNotFoundError, RuntimeError) as e:
124123
self._logger.debug(
125124
f"Try to import module {module_name}, but failed: {secure_format_exception(e)}. "
126125
f"Can't use name in config to refer to classes in module: {module_name}."
127126
)
128127
pass
129-
return class_table
130128

131129
def get_module_name(self, class_name) -> Optional[str]:
132130
"""Gets the name of the module that contains this class.

nvflare/utils/components_utils.py

-63
This file was deleted.

0 commit comments

Comments
 (0)