|
1 | 1 | import importlib |
2 | 2 |
|
| 3 | +from .base_executor import Executor |
| 4 | + |
3 | 5 | _IMPORT_ERRORS: dict[str, str] = {} |
4 | 6 |
|
5 | 7 |
|
6 | | -def _safe_import(name: str, module: str) -> type | None: |
| 8 | +def _import_executor(name: str, module: str) -> type[Executor] | None: |
7 | 9 | try: |
8 | 10 | pkg = importlib.import_module(module, package=__package__) |
9 | | - return getattr(pkg, name) |
| 11 | + if issubclass(cls := getattr(pkg, name), Executor): |
| 12 | + return cls |
| 13 | + error = f"{name} is not a subclass of Executor" |
10 | 14 | except Exception as exc: |
11 | | - _IMPORT_ERRORS[name] = str(exc) |
12 | | - return None |
| 15 | + error = str(exc) |
| 16 | + _IMPORT_ERRORS[name] = error |
| 17 | + return None |
13 | 18 |
|
14 | 19 |
|
15 | | -VLLMExecutor = _safe_import("VLLMExecutor", ".vllm_executor") |
16 | | -VLLMLoRAExecutor = _safe_import("VLLMLoRAExecutor", ".vllm_lora_executor") |
17 | | -PPOExecutor = _safe_import("PPOExecutor", ".ppo_executor") |
18 | | -DPOExecutor = _safe_import("DPOExecutor", ".dpo_executor") |
19 | | -SFTExecutor = _safe_import("SFTExecutor", ".sft_executor") |
20 | | -LoRASFTExecutor = _safe_import("LoRASFTExecutor", ".lora_sft_executor") |
21 | | -ImageClassificationTrainingExecutor = _safe_import( |
| 20 | +VLLMExecutor = _import_executor("VLLMExecutor", ".vllm_executor") |
| 21 | +VLLMLoRAExecutor = _import_executor("VLLMLoRAExecutor", ".vllm_lora_executor") |
| 22 | +PPOExecutor = _import_executor("PPOExecutor", ".ppo_executor") |
| 23 | +DPOExecutor = _import_executor("DPOExecutor", ".dpo_executor") |
| 24 | +SFTExecutor = _import_executor("SFTExecutor", ".sft_executor") |
| 25 | +LoRASFTExecutor = _import_executor("LoRASFTExecutor", ".lora_sft_executor") |
| 26 | +ImageClassificationTrainingExecutor = _import_executor( |
22 | 27 | "ImageClassificationTrainingExecutor", ".image_classification_executor" |
23 | 28 | ) |
24 | | -HFTransformersExecutor = _safe_import( |
| 29 | +HFTransformersExecutor = _import_executor( |
25 | 30 | "HFTransformersExecutor", ".transformers_executor" |
26 | 31 | ) |
27 | | -RAGExecutor = _safe_import("RAGExecutor", ".rag_executor") |
28 | | -AgentExecutor = _safe_import("AgentExecutor", ".agent_executor") |
29 | | -EchoExecutor = _safe_import("EchoExecutor", ".echo_executor") |
30 | | -DataProfilingExecutor = _safe_import( |
| 32 | +RAGExecutor = _import_executor("RAGExecutor", ".rag_executor") |
| 33 | +AgentExecutor = _import_executor("AgentExecutor", ".agent_executor") |
| 34 | +EchoExecutor = _import_executor("EchoExecutor", ".echo_executor") |
| 35 | +DataProfilingExecutor = _import_executor( |
31 | 36 | "DataProfilingExecutor", ".data_profiling_executor" |
32 | 37 | ) |
33 | | -DataRetrievalExecutor = _safe_import( |
| 38 | +DataRetrievalExecutor = _import_executor( |
34 | 39 | "DataRetrievalExecutor", ".data_retrieval_executor" |
35 | 40 | ) |
36 | | -DiffusersExecutor = _safe_import("DiffusersExecutor", ".diffusers_executor") |
37 | | -APIExecutor = _safe_import("APIExecutor", ".api_executor") |
38 | | -SSHExecutor = _safe_import("SSHExecutor", ".ssh_executor") |
39 | | -OmniText2ImageExecutor = _safe_import( |
| 41 | +DiffusersExecutor = _import_executor("DiffusersExecutor", ".diffusers_executor") |
| 42 | +APIExecutor = _import_executor("APIExecutor", ".api_executor") |
| 43 | +SSHExecutor = _import_executor("SSHExecutor", ".ssh_executor") |
| 44 | +OmniText2ImageExecutor = _import_executor( |
40 | 45 | "OmniText2ImageExecutor", ".omni_text2image_executor" |
41 | 46 | ) |
42 | | -OmniText2SpeechExecutor = _safe_import( |
| 47 | +OmniText2SpeechExecutor = _import_executor( |
43 | 48 | "OmniText2SpeechExecutor", ".omni_text2speech_executor" |
44 | 49 | ) |
45 | | -OmniText2AudioExecutor = _safe_import( |
| 50 | +OmniText2AudioExecutor = _import_executor( |
46 | 51 | "OmniText2AudioExecutor", ".omni_text2audio_executor" |
47 | 52 | ) |
48 | | -OmniText2GeneralExecutor = _safe_import( |
| 53 | +OmniText2GeneralExecutor = _import_executor( |
49 | 54 | "OmniText2GeneralExecutor", ".omni_text2general_executor" |
50 | 55 | ) |
51 | 56 |
|
52 | | -EXECUTOR_REGISTRY: dict[str, type | None] = { |
| 57 | +EXECUTOR_REGISTRY: dict[str, type[Executor] | None] = { |
53 | 58 | "vllm": VLLMExecutor, |
54 | 59 | "vllm_lora": VLLMLoRAExecutor, |
55 | 60 | "ppo": PPOExecutor, |
|
0 commit comments