From 34d6fe353ac06c5d3cf519642b34c58cb59807d2 Mon Sep 17 00:00:00 2001 From: Bob Yang Date: Thu, 20 Mar 2025 11:00:07 -0700 Subject: [PATCH] Componenets to get fn returning Any (#1018) Summary: Pull Request resolved: https://github.com/pytorch/torchx/pull/1018 Allow `Any` so that we can have different function signatures for pipeline definition Reviewed By: lgarg26 Differential Revision: D70936712 --- torchx/specs/builders.py | 8 ++++++-- torchx/specs/finder.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torchx/specs/builders.py b/torchx/specs/builders.py index 998e81076..ce7097d56 100644 --- a/torchx/specs/builders.py +++ b/torchx/specs/builders.py @@ -101,7 +101,7 @@ def _merge_config_values_with_args( def parse_args( - cmpnt_fn: Callable[..., AppDef], + cmpnt_fn: Callable[..., Any], # pyre-ignore[2] cmpnt_args: List[str], cmpnt_defaults: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, @@ -130,7 +130,7 @@ def parse_args( def materialize_appdef( - cmpnt_fn: Callable[..., AppDef], + cmpnt_fn: Callable[..., Any], # pyre-ignore[2] cmpnt_args: List[str], cmpnt_defaults: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, @@ -187,6 +187,10 @@ def materialize_appdef( var_arg = var_arg[1:] appdef = cmpnt_fn(*function_args, *var_arg, **kwargs) + if not isinstance(appdef, AppDef): + raise TypeError( + f"Expected a component that returns `AppDef`, but got `{type(appdef)}`" + ) return appdef diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index 980aa9473..ab1284a7b 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -16,7 +16,7 @@ from inspect import getmembers, isfunction from pathlib import Path from types import ModuleType -from typing import Callable, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Union from torchx.specs import AppDef from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate @@ -24,6 +24,7 @@ from torchx.util.io import read_conf_file from torchx.util.types import none_throws + logger: logging.Logger = logging.getLogger(__name__) @@ -53,7 +54,10 @@ class _Component: name: str description: str fn_name: str - fn: Callable[..., AppDef] + + # pyre-ignore[4] TODO temporary until PipelineDef is decoupled and can be exposed as type to OSS + fn: Callable[..., Any] + validation_errors: List[str]