diff --git a/README.md b/README.md index a4541fb6..14a9095c 100644 --- a/README.md +++ b/README.md @@ -207,21 +207,22 @@ The `sktime-mcp` server exposes a rich suite of tools categorized logically. Eve ### Category 1: Discovery & Registry Tools -These tools enable the LLM to inspect the native `sktime` registry, search for estimators matching specific criteria, and understand their capability profiles. +These tools enable the LLM to inspect the native `sktime` registry, search for estimators, capability tags, or performance metrics, and understand their capability profiles. -#### 1. `list_estimators` -Discover estimator classes from the `sktime` registry by task type, capabilities, or name query. +#### 1. `query_registry` +Unified entry point to discover and search for estimators, capability tags, or performance metrics in the `sktime` ecosystem. * **Arguments:** - * `task` (`str`, optional): Filter by task type. Valid values: `"forecasting"`, `"classification"`, `"regression"`, `"transformation"`, `"clustering"`, `"detection"`. - * `tags` (`dict[str, Any]`, optional): Key-value pairs filtering by capability tags (e.g., `{"capability:pred_int": true, "handles-missing-data": true}`). - * `query` (`str`, optional): Case-insensitive substring search on the estimator's class name or docstring. + * `target` (`str`, optional, default=`"estimators"`): What registry target to search. Valid values: `"estimators"`, `"tags"`, `"metrics"`. + * `task` (`str`, optional): Filter estimators or metrics by task type (e.g. `"forecasting"`, `"classification"`, `"regression"`, `"transformation"`, `"clustering"`, `"detection"`, `"metric"`). + * `tags` (`dict[str, Any]`, optional): Key-value pairs filtering estimators by capability tags. Values must match each tag's expected type (bool, str, list, etc.; use `target="tags"` to inspect `value_type`). Example: `{"capability:pred_int": true, "scitype:y": "univariate"}`. + * `query` (`str`, optional): Case-insensitive substring search on names or descriptions. 
* `limit` (`int`, optional, default=`50`): Maximum number of results to return. * `offset` (`int`, optional, default=`0`): Number of entries to skip for pagination. * **Returns:** ```json { "success": true, - "estimators": [ + "results": [ { "name": "ARIMA", "task": "forecasting", @@ -233,48 +234,32 @@ Discover estimator classes from the `sktime` registry by task type, capabilities "import_path": "sktime.forecasting.arima.ARIMA" } ], + "count": 1, "total": 1 } ``` -#### 2. `describe_estimator` -Get full documentation, hyperparameters, capability tags, and Python import path for a specific named estimator class. +#### 2. `describe_component` +Get full documentation, constructor parameters, capability tags, and Python import path for ANY specific class or component in the `sktime` ecosystem (estimators, transformers, splitters, metrics, aligners). * **Arguments:** - * `estimator` (`str`, required): Name of the estimator class (e.g., `"ARIMA"`, `"NaiveForecaster"`). + * `name` (`str`, required): Name of the component class (e.g., `"ARIMA"`, `"SlidingWindowSplitter"`, `"MeanAbsolutePercentageError"`). * **Returns:** ```json { "success": true, "name": "ARIMA", - "description": "Autoregressive Integrated Moving Average...", - "params": { - "order": { - "type": "tuple", - "default": [1, 0, 0], - "description": "The (p,d,q) order of the model." - } + "task": "forecasting", + "module": "sktime.forecasting.arima", + "parameters": { + "order": {"default": [1, 0, 0], "required": false} }, "tags": { "capability:pred_int": true }, - "import_path": "sktime.forecasting.arima.ARIMA" - } - ``` - -#### 3. `get_available_tags` -List all queryable capability tags across the `sktime` registry with their descriptions and expected value types. -* **Arguments:** None. 
-* **Returns:** - ```json - { - "success": true, - "tags": { - "capability:pred_int": { - "type": "bool", - "description": "Whether the forecaster can compute prediction intervals.", - "scitype": "forecaster" - } - } + "tag_explanations": { + "capability:pred_int": "Whether the forecaster can compute prediction intervals." + }, + "docstring": "Autoregressive Integrated Moving Average..." } ``` @@ -945,8 +930,9 @@ Below are examples demonstrating how an LLM utilizes these redesigned tools to c 1. **Discover Models** The LLM queries for forecasting models supporting prediction intervals. ```json - // list_estimators + // query_registry { + "target": "estimators", "task": "forecasting", "tags": { "capability:pred_int": true @@ -958,9 +944,9 @@ Below are examples demonstrating how an LLM utilizes these redesigned tools to c 2. **Inspect Choice** The LLM inspects the parameter schema of `"ARIMA"`. ```json - // describe_estimator + // describe_component { - "estimator": "ARIMA" + "name": "ARIMA" } ``` diff --git a/docs/source/implementation.md b/docs/source/implementation.md index af433493..d2531df1 100644 --- a/docs/source/implementation.md +++ b/docs/source/implementation.md @@ -81,7 +81,7 @@ The codebase is organized into **5 main layers**: 2. **`@server.list_tools()`**: Registers all available MCP tools - Returns tool schemas (name, description, input schema) - - Tools span Discovery, Instantiation, Execution, Data, Export, Persistence, Validation, and Job Management. (e.g., `list_estimators`, `instantiate_pipeline`, `fit_predict_async`, `load_data_source`, `save_model`, `check_job_status`). + - Tools span Discovery, Instantiation, Execution, Data, Export, Persistence, Validation, and Job Management. (e.g., `query_registry`, `instantiate_pipeline`, `fit_predict_async`, `load_data_source`, `save_model`, `check_job_status`). 3. 
**`@server.call_tool(name, arguments)`**: Routes tool calls to implementations - Validates arguments @@ -306,19 +306,14 @@ Each file implements one or more MCP tools that LLMs can call. #### `list_estimators.py` **Tools**: -1. **`list_estimators_tool(task, tags, query, limit)`** - - Calls `registry.get_all_estimators(task, tags)` - - Returns: `{"success": True, "estimators": [...], "total": 50}` +1. **`query_registry_tool(target, task, tags, query, limit, offset)`** + - Queries the unified registry for estimators, capability tags, or performance metrics. + - Returns query results. -2. **`get_available_tags()`** - - Returns all capability tags - - Example: `["capability:pred_int", "handles-missing-data", ...]` - -#### `describe_estimator.py` -**Tool**: `describe_estimator_tool(estimator)` -- Looks up estimator in registry -- Returns full EstimatorNode details -- Includes: name, task, module, tags, hyperparameters, docstring +#### `describe_component.py` +**Tool**: `describe_component_tool(name)` +- Looks up a component (estimator, splitter, metric, transformer) in the registry by name +- Returns detailed component information #### `instantiate.py` **Tools**: @@ -382,14 +377,14 @@ Each file implements one or more MCP tools that LLMs can call. **Steps**: 1. List datasets -2. Discover forecasting estimators +2. Discover forecasting estimators using `query_registry` 3. Filter by tags (probabilistic forecasters) -4. Describe an estimator +4. Describe a component using `describe_component` 5. Validate pipeline compositions 6. Instantiate estimator 7. Fit and predict 8. List active handles -9. Show available tags +9. Show available tags using `query_registry` **Run**: `python examples/01_forecasting_workflow.py` @@ -399,8 +394,8 @@ Each file implements one or more MCP tools that LLMs can call. **Scenario**: User asks "Forecast airline passengers with a probabilistic model" **LLM Steps**: -1. `list_estimators(task="forecasting", tags={"capability:pred_int": True})` -2. 
`describe_estimator("ARIMA")` +1. `query_registry(target="estimators", task="forecasting", tags={"capability:pred_int": True})` +2. `describe_component("ARIMA")` 3. `instantiate_estimator("ARIMA", {"order": [1,1,1]})` 4. `fit_predict(handle, "airline", 12)` @@ -478,19 +473,19 @@ Each file implements one or more MCP tools that LLMs can call. **Step 1: Discovery** ``` -LLM → list_estimators(task="forecasting") - → server.call_tool("list_estimators", {"task": "forecasting"}) - → list_estimators_tool(task="forecasting") +LLM → query_registry(target="estimators", task="forecasting") + → server.call_tool("query_registry", {"target": "estimators", "task": "forecasting"}) + → query_registry_tool(target="estimators", task="forecasting") → registry.get_all_estimators(task="forecasting") → Returns: [{"name": "ARIMA", ...}, {"name": "NaiveForecaster", ...}, ...] ``` **Step 2: Description** ``` -LLM → describe_estimator("ARIMA") - → describe_estimator_tool("ARIMA") +LLM → describe_component("ARIMA") + → describe_component_tool("ARIMA") → registry.get_estimator_by_name("ARIMA") - → Returns: {"name": "ARIMA", "hyperparameters": {"order": ...}, ...} + → Returns: {"name": "ARIMA", "parameters": {"order": ...}, ...} ``` **Step 3: Instantiation** diff --git a/docs/source/usage-examples.md b/docs/source/usage-examples.md index cbb0679d..347d32f0 100644 --- a/docs/source/usage-examples.md +++ b/docs/source/usage-examples.md @@ -43,8 +43,9 @@ Examples of tasks performed via sktime-mcp and their corresponding MCP tool call **1. Discover models with prediction interval support:** ```json { - "name": "list_estimators", + "name": "query_registry", "arguments": { + "target": "estimators", "task": "forecasting", "tags": {"capability:pred_int": true} } @@ -54,8 +55,8 @@ Examples of tasks performed via sktime-mcp and their corresponding MCP tool call **2. 
Inspect the chosen estimator:** ```json { - "name": "describe_estimator", - "arguments": {"estimator": "ARIMA"} + "name": "describe_component", + "arguments": {"name": "ARIMA"} } ``` diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index 8951258f..d83d2a18 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -45,7 +45,9 @@ load_data_source_tool, release_data_handle_tool, ) -from sktime_mcp.tools.describe_estimator import describe_estimator_tool +from sktime_mcp.tools.describe_component import ( + describe_component_tool, +) from sktime_mcp.tools.evaluate import evaluate_estimator_tool from sktime_mcp.tools.fit_predict import ( fit_predict_async_tool, @@ -66,8 +68,7 @@ ) from sktime_mcp.tools.list_available_data import list_available_data_tool from sktime_mcp.tools.list_estimators import ( - get_available_tags, - list_estimators_tool, + query_registry_tool, ) from sktime_mcp.tools.save_model import save_model_tool @@ -205,72 +206,94 @@ async def list_tools() -> list[Tool]: return [ # -- Discovery ------------------------------------------------------- Tool( - name="list_estimators", + name="query_registry", description=( - "Discover sktime estimators by task, capability tags, or name search. " - "Common tags you can filter by: " - "'capability:pred_int' (bool) - prediction intervals, " - "'capability:multivariate' (bool) - multivariate support, " - "'handles-missing-data' (bool) - NaN handling, " - "'scitype:y' (str) - target type ('univariate'/'multivariate'/'both'), " - "'requires-fh-in-fit' (bool) - needs forecast horizon at fit time. " - "Use get_available_tags for the full catalog." + "Unified entry point to search the sktime component registry. " + "target='estimators' (default): discover forecasters, classifiers, transformers, " + "splitters, detectors, and other components; filter by task, capability tags, " + "or name/module/docstring substring; results include name, task, module, and tags. 
" + "target='tags': list all capability tags with descriptions, expected value_type " + "(bool, str, list, etc.), and which component types each tag applies to — " + "call this before filtering estimators by tags. " + "target='metrics': list performance metrics (task='metric'). " + "Supports pagination via limit (default 50) and offset (default 0)." ), inputSchema={ "type": "object", "properties": { + "target": { + "type": "string", + "description": ( + "Registry section to query: 'estimators' (default), 'tags', or 'metrics'. " + "'estimators' returns component summaries; 'tags' returns tag metadata; " + "'metrics' returns performance metric classes." + ), + "enum": ["estimators", "tags", "metrics"], + "default": "estimators", + }, "task": { "type": "string", - "description": "Task type filter: forecasting, classification, regression, transformation, clustering, detection", + "description": ( + "Filter estimators or metrics by task type. Valid values include " + "forecasting, classification, regression, transformation, clustering, " + "splitting, detection, alignment, parameter_estimation, network, and metric. " + "Only applies when target is 'estimators' or 'metrics'." + ), }, "tags": { "type": "object", - "description": "Filter by capability tags, e.g. {'capability:pred_int': true}", + "description": ( + "Filter estimators by capability tags (only when target='estimators'). " + "Values must match each tag's expected type — boolean, string, or list. " + "Use target='tags' first to see valid tag names and value_type for each. " + "Example: {'capability:pred_int': true, 'scitype:y': 'univariate'}" + ), }, "query": { "type": "string", "description": ( - "Search by name or description (substring, case-insensitive). " - "Can be combined with task and tags filters." + "Case-insensitive substring search over component names, modules, and " + "docstrings (target='estimators' or 'metrics'), or tag names and " + "descriptions (target='tags')." 
), }, "limit": { "type": "integer", - "description": "Maximum results (default: 50)", + "description": "Maximum results per page (default: 50). Must be a positive integer.", "default": 50, }, "offset": { "type": "integer", - "description": "Skip this many results for pagination (default: 0)", + "description": "Number of results to skip for pagination (default: 0).", "default": 0, }, }, }, ), Tool( - name="describe_estimator", - description="Get detailed information about a specific sktime estimator", + name="describe_component", + description=( + "Get detailed metadata for a single sktime component class — estimators, " + "transformers, splitters, metrics, aligners, detectors, etc. " + "Returns constructor parameters (default values and required flags), " + "capability tags (values may be bool, str, list, or null), " + "human-readable tag explanations, Python module import path, and a " + "docstring preview. Use query_registry to discover valid class names first." + ), inputSchema={ "type": "object", "properties": { - "estimator": { + "name": { "type": "string", - "description": "Name of the estimator (e.g., 'ARIMA', 'RandomForest')", + "description": ( + "Component class name (case-insensitive), e.g. 'ARIMA', " + "'SlidingWindowSplitter', 'MeanAbsolutePercentageError'." + ), }, }, - "required": ["estimator"], + "required": ["name"], }, ), - Tool( - name="get_available_tags", - description=( - "List all queryable capability tags with rich metadata. " - "Returns tag name, description, expected value type, and which " - "estimator types the tag applies to. Call this before " - "using tags in list_estimators to ensure correct tag names and values." 
- ), - inputSchema={"type": "object", "properties": {}}, - ), # -- Instantiation --------------------------------------------------- Tool( name="instantiate_estimator", @@ -747,8 +770,9 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: try: # -- Discovery ------------------------------------------------------- - if name == "list_estimators": - result = list_estimators_tool( + if name == "query_registry": + result = query_registry_tool( + target=arguments.get("target", "estimators"), task=arguments.get("task"), tags=arguments.get("tags"), query=arguments.get("query"), @@ -756,19 +780,8 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: offset=arguments.get("offset", 0), ) - elif name == "search_estimators": - # Deprecated — kept for backward compatibility, routes to unified list_estimators - logger.warning("search_estimators is deprecated; use list_estimators(query=...)") - result = list_estimators_tool( - query=arguments["query"], - limit=arguments.get("limit", 20), - ) - - elif name == "describe_estimator": - result = describe_estimator_tool(arguments["estimator"]) - - elif name == "get_available_tags": - result = get_available_tags() + elif name == "describe_component": + result = describe_component_tool(name=arguments["name"]) # -- Instantiation --------------------------------------------------- elif name == "instantiate_estimator": diff --git a/src/sktime_mcp/tools/__init__.py b/src/sktime_mcp/tools/__init__.py index d8d0f0d7..61a421c3 100644 --- a/src/sktime_mcp/tools/__init__.py +++ b/src/sktime_mcp/tools/__init__.py @@ -6,7 +6,9 @@ load_data_source_tool, release_data_handle_tool, ) -from sktime_mcp.tools.describe_estimator import describe_estimator_tool +from sktime_mcp.tools.describe_component import ( + describe_component_tool, +) from sktime_mcp.tools.evaluate import evaluate_estimator_tool from sktime_mcp.tools.fit_predict import ( fit_predict_async_tool, @@ -27,15 +29,13 @@ ) from 
sktime_mcp.tools.list_available_data import list_available_data_tool from sktime_mcp.tools.list_estimators import ( - get_available_tags, - list_estimators_tool, + query_registry_tool, ) from sktime_mcp.tools.save_model import save_model_tool __all__ = [ - "list_estimators_tool", - "get_available_tags", - "describe_estimator_tool", + "describe_component_tool", + "query_registry_tool", "instantiate_estimator_tool", "instantiate_pipeline_tool", "list_handles_tool", diff --git a/src/sktime_mcp/tools/describe_component.py b/src/sktime_mcp/tools/describe_component.py new file mode 100644 index 00000000..6b3b65ec --- /dev/null +++ b/src/sktime_mcp/tools/describe_component.py @@ -0,0 +1,80 @@ +""" +describe_component tool for sktime MCP. + +Gets detailed metadata about any sktime component class. +""" + +from typing import Any + +from sktime_mcp.registry.interface import get_registry +from sktime_mcp.registry.tag_resolver import get_tag_resolver + + +def describe_component_tool(name: str) -> dict[str, Any]: + """Get detailed information about any class or component in the sktime ecosystem. + + This includes estimators, transformers, splitters, metrics, and aligners. + + Parameters + ---------- + name : str + Name of the component class (e.g., "ARIMA", "SlidingWindowSplitter", + "MeanAbsolutePercentageError"). Case-insensitive. + + Returns + ------- + dict + A dictionary containing detailed component information: + - "success" : bool + True if the component was found, False otherwise. + - "name" : str + The formal name of the component class. + - "task" : str + The task type of the component (e.g., "forecasting", "splitting", "metric"). + - "module" : str + The full import path of the module containing the component. + - "parameters" : dict + A dictionary mapping constructor parameter names to metadata dicts + with ``default`` and ``required`` keys. + - "tags" : dict + A dictionary mapping capability tag names to their values. 
+ Values may be bool, str, list, or null depending on the tag. + - "tag_explanations" : dict + A dictionary mapping capability tags to human-readable explanations. + - "docstring" : str + A preview of the component's docstring (first 500 characters). + - "error" : str, optional + Error message if "success" is False. + """ + registry = get_registry() + tag_resolver = get_tag_resolver() + + node = registry.get_estimator_by_name(name) + if node is None: + # Try case-insensitive search + all_estimators = registry.get_all_estimators() + matches = [e for e in all_estimators if e.name.lower() == name.lower()] + if matches: + node = matches[0] + else: + return { + "success": False, + "error": f"Unknown component class: {name}", + "suggestion": "Use query_registry to discover available component classes", + } + + # Get tag explanations + tag_explanations = tag_resolver.explain_tags(node.tags) + + doc = node.docstring or "No description available." + + return { + "success": True, + "name": node.name, + "task": node.task, + "module": node.module, + "parameters": node.hyperparameters, + "tags": node.tags, + "tag_explanations": tag_explanations, + "docstring": doc[:500], + } diff --git a/src/sktime_mcp/tools/describe_estimator.py b/src/sktime_mcp/tools/describe_estimator.py deleted file mode 100644 index 1415ee4b..00000000 --- a/src/sktime_mcp/tools/describe_estimator.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -describe_estimator tool for sktime MCP. -Gets detailed information about an estimator's capabilities. -""" - -from typing import Any - -from sktime_mcp.registry.interface import get_registry -from sktime_mcp.registry.tag_resolver import get_tag_resolver - - -def describe_estimator_tool(estimator: str) -> dict[str, Any]: - """ - Get detailed information about a specific estimator. 
- - Args: - estimator: Name of the estimator class (e.g., "ARIMA", "RandomForest") - - Returns: - Dictionary with: - - success: bool - - name: Estimator name - - task: Task type - - module: Full module path - - hyperparameters: Dict of parameter names with defaults - - tags: Dict of capability tags - - tag_explanations: Human-readable tag descriptions - - docstring: First 500 chars of docstring - - Example: - >>> describe_estimator_tool("ARIMA") - { - "success": True, - "name": "ARIMA", - "task": "forecasting", - "hyperparameters": {"order": {"default": [1,0,0], "required": False}, ...}, - "tags": {"capability:pred_int": True, ...}, - ... - } - """ - registry = get_registry() - tag_resolver = get_tag_resolver() - - node = registry.get_estimator_by_name(estimator) - if node is None: - # Try case-insensitive search - all_estimators = registry.get_all_estimators() - matches = [e for e in all_estimators if e.name.lower() == estimator.lower()] - if matches: - node = matches[0] - else: - return { - "success": False, - "error": f"Unknown estimator: {estimator}", - "suggestion": "Use list_estimators to discover available estimators", - } - - # Get tag explanations - tag_explanations = tag_resolver.explain_tags(node.tags) - - doc = node.docstring or "No description available." - - return { - "success": True, - "name": node.name, - "task": node.task, - "module": node.module, - "hyperparameters": node.hyperparameters, - "tags": node.tags, - "tag_explanations": tag_explanations, - "docstring": doc[:500], - } - - -def search_estimators_tool(query: str, limit: int = 20) -> dict[str, Any]: - """ - Search estimators by name or description. 
- - Args: - query: Search string (case-insensitive) - limit: Maximum results - - Returns: - Dictionary with matching estimators - """ - if limit < 1: - return { - "success": False, - "error": "limit must be a positive integer.", - } - - registry = get_registry() - - try: - matches = registry.search_estimators(query)[:limit] - return { - "success": True, - "query": query, - "results": [est.to_summary() for est in matches], - "count": len(matches), - } - except Exception as e: - return {"success": False, "error": str(e)} diff --git a/src/sktime_mcp/tools/list_estimators.py b/src/sktime_mcp/tools/list_estimators.py index e10355c8..c365231f 100644 --- a/src/sktime_mcp/tools/list_estimators.py +++ b/src/sktime_mcp/tools/list_estimators.py @@ -9,40 +9,113 @@ from sktime_mcp.registry.interface import get_registry -def list_estimators_tool( +def query_registry_tool( + target: str = "estimators", task: str | None = None, tags: dict[str, Any] | None = None, query: str | None = None, limit: int = 50, offset: int = 0, ) -> dict[str, Any]: - """ - Discover sktime estimators by task type, capability tags, and/or name search. - - All filters are combined: query narrows by name/docstring, then task and tags - are applied on top. - - Args: - task: Filter by task type. Options: "forecasting", "classification", - "regression", "transformation", "clustering", "detection" - tags: Filter by capability tags. Example: {"capability:pred_int": True} - query: Search by name or description (substring, case-insensitive). - limit: Maximum number of results to return (default: 50) - offset: Number of results to skip for pagination (default: 0). 
- - Returns: - Dictionary with: - - success: bool - - estimators: List of estimator summaries - - count: Number of results returned in this page - - total: Total matching estimators (before limit/offset) - - offset: Current offset (for pagination) - - limit: Current limit (for pagination) - - has_more: True if more results exist beyond this page + """Query the sktime registry for estimators, capability tags, or performance metrics. + + Parameters + ---------- + target : str, default="estimators" + The registry target to search. Must be one of the following: + - "estimators" : search for sktime estimators/components (e.g. forecasters, classifiers) + - "tags" : list/filter available capability tags + - "metrics" : search for performance metrics + task : str or None, default=None + Filter estimators or metrics by task type. Valid values include + "forecasting", "transformation", "classification", "regression", + "clustering", "splitting", "detection", "alignment", "parameter_estimation", + "network", and "metric". Only applies when `target` is "estimators" or "metrics". + tags : dict or None, default=None + Key-value pairs of capability tag filters. Values must match each tag's + expected type (bool, str, list, etc.; use ``target="tags"`` to inspect + ``value_type``). Example: ``{"capability:pred_int": True, "scitype:y": "univariate"}``. + Only applies when `target` is "estimators". + query : str or None, default=None + Case-insensitive substring search over component names, modules, and docstrings + (`target` "estimators" or "metrics"), or over tag names and descriptions (`target` "tags"). + limit : int, default=50 + Maximum number of results to return. Must be a positive integer. + offset : int, default=0 + Number of results to skip for pagination. Must be a non-negative integer. + + Returns + ------- + dict + A dictionary with the query results and metadata containing: + - "success" : bool + True if the query completed successfully, False otherwise. 
+ - "results" : list of dict + List of matching components, metrics, or tags in summary form. + - "count" : int + Number of results in the current page. + - "total" : int + Total number of matching results across all pages. + - "offset" : int + The offset used for pagination. + - "limit" : int + The limit used for pagination. + - "has_more" : bool + True if there are more results to fetch, False otherwise. + - "target" : str + The target that was queried. + - "task_filter" : str or None + The task filter that was applied. + - "tag_filter" : dict or None + The tag filter that was applied. + - "query" : str or None + The search query that was applied. + - "error" : str, optional + Error message if "success" is False. """ registry = get_registry() try: - # Validate task + # Validate target + valid_targets = ["estimators", "tags", "metrics"] + if target not in valid_targets: + return { + "success": False, + "error": f"Invalid target: '{target}'. Valid targets: {valid_targets}", + } + + # Check pagination bounds + if offset < 0: + return {"success": False, "error": "offset must be a non-negative integer."} + if limit < 1: + return {"success": False, "error": "limit must be a positive integer."} + + # Handle target: tags + if target == "tags": + all_tags = registry.get_available_tags() + # If query is provided, filter tags by name/description + if query: + q_lower = query.lower() + all_tags = [ + t + for t in all_tags + if q_lower in t.get("tag", "").lower() + or q_lower in t.get("description", "").lower() + ] + + total = len(all_tags) + page = all_tags[offset : offset + limit] + return { + "success": True, + "results": page, + "count": len(page), + "total": total, + "offset": offset, + "limit": limit, + "has_more": (offset + limit) < total, + } + + # Handle target: metrics or estimators + # Validate task if provided if task is not None: valid_tasks = registry.get_available_tasks() if task not in valid_tasks: @@ -53,7 +126,7 @@ def list_estimators_tool( + (f" Did you 
mean: {suggestions}?" if suggestions else ""), } - # Validate tag keys + # Validate tag keys if provided if tags is not None: valid_tag_keys = {t["tag"] for t in registry.get_available_tags()} invalid_keys = [k for k in tags if k not in valid_tag_keys] @@ -64,68 +137,53 @@ def list_estimators_tool( } return { "success": False, - "error": f"Invalid tag key(s): {invalid_keys}. Use get_available_tags to see valid keys.", + "error": f"Invalid tag key(s): {invalid_keys}. Use target='tags' to see valid keys.", "suggestions": {k: v[0] if v else None for k, v in suggestions.items()}, } - if query: - estimators = registry.search_estimators(query) + # Fetch base list of components + if target == "metrics": + # Metrics are components with task == "metric" + components = registry.get_all_estimators(task="metric") + else: + # Estimators are everything except metrics + components = [e for e in registry.get_all_estimators() if e.task != "metric"] + # Apply task filter if task: - estimators = [e for e in estimators if e.task == task] + components = [e for e in components if e.task == task] + # Apply tags filter if tags: - estimators = registry._filter_by_tags(estimators, tags) - else: - estimators = registry.get_all_estimators(task=task, tags=tags) - - total = len(estimators) - - if offset < 0: - return { - "success": False, - "error": "offset must be a non-negative integer.", - } - - if limit < 1: - return { - "success": False, - "error": "limit must be a positive integer.", - } + components = registry._filter_by_tags(components, tags) - page = estimators[offset : offset + limit] + # Apply query search if provided + if query: + q_lower = query.lower() + filtered_components = [] + for node in components: + name_lower = node.name.lower() + module_lower = node.module.lower() + doc_lower = node.docstring.lower() if node.docstring else "" + if q_lower in name_lower or q_lower in module_lower or q_lower in doc_lower: + filtered_components.append(node) + components = filtered_components 
+ + # Pagination + total = len(components) + page = components[offset : offset + limit] results = [est.to_summary() for est in page] return { "success": True, - "estimators": results, + "results": results, "count": len(results), "total": total, "offset": offset, "limit": limit, "has_more": (offset + limit) < total, + "target": target, "task_filter": task, "tag_filter": tags, "query": query, } except Exception as e: - return { - "success": False, - "error": str(e), - } - - -def get_available_tasks() -> dict[str, Any]: - """Get list of available task types.""" - registry = get_registry() - return { - "success": True, - "tasks": registry.get_available_tasks(), - } - - -def get_available_tags() -> dict[str, Any]: - """Get list of all available capability tags.""" - registry = get_registry() - return { - "success": True, - "tags": registry.get_available_tags(), - } + return {"success": False, "error": str(e)} diff --git a/tests/test_core.py b/tests/test_core.py index 2326c2b4..fef8de09 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -33,7 +33,7 @@ def test_filter_by_task(self): def test_detection_estimators_in_registry(self): """Registry loads sktime `detector` scitype as MCP task `detection`.""" from sktime_mcp.registry.interface import get_registry - from sktime_mcp.tools.describe_estimator import search_estimators_tool + from sktime_mcp.tools.list_estimators import query_registry_tool registry = get_registry() detectors = registry.get_all_estimators(task="detection") @@ -43,7 +43,7 @@ def test_detection_estimators_in_registry(self): names = {e.name for e in detectors} assert "PyODDetector" in names or "BinarySegmentation" in names - anomaly_search = search_estimators_tool("anomaly", limit=50) + anomaly_search = query_registry_tool(target="estimators", query="anomaly", limit=50) assert anomaly_search["success"] assert anomaly_search["count"] > 2 @@ -184,42 +184,42 @@ def test_transformer_to_forecaster_valid(self): class TestTools: """Tests for MCP 
tools."""

-    def test_list_estimators_tool(self):
-        """Test list_estimators tool."""
-        from sktime_mcp.tools.list_estimators import list_estimators_tool
+    def test_query_registry_tool(self):
+        """Test query_registry tool."""
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = list_estimators_tool(limit=5)
+        result = query_registry_tool(target="estimators", limit=5)

         assert result["success"]
-        assert "estimators" in result
-        assert len(result["estimators"]) <= 5
+        assert "results" in result
+        assert len(result["results"]) <= 5

-    def test_list_estimators_detection_task(self):
+    def test_query_registry_detection_task(self):
         """Test that detection estimators are returned when filtering by detection task."""
-        from sktime_mcp.tools.list_estimators import list_estimators_tool
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = list_estimators_tool(task="detection", limit=100)
+        result = query_registry_tool(target="estimators", task="detection", limit=100)

         assert result["success"]
         assert result["total"] > 0, "There should be detection estimators"
-        assert all(e["task"] == "detection" for e in result["estimators"]), (
+        assert all(e["task"] == "detection" for e in result["results"]), (
             "All returned estimators should have task='detection'"
         )

-    def test_detection_in_available_tasks(self):
-        """Test that detection appears in available tasks."""
-        from sktime_mcp.tools.list_estimators import get_available_tasks
+    def test_query_registry_invalid_task(self):
+        """Test query_registry with an invalid task."""
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = get_available_tasks()
+        result = query_registry_tool(target="estimators", task="invalid_task_name")

-        assert result["success"]
-        assert "detection" in result["tasks"], "detection should be a valid task"
+        assert not result["success"]
+        assert "error" in result

-    def test_describe_unknown_estimator(self):
-        """Test describing an unknown estimator."""
-        from sktime_mcp.tools.describe_estimator import describe_estimator_tool
+    def test_describe_unknown_component(self):
+        """Test describing an unknown component."""
+        from sktime_mcp.tools.describe_component import describe_component_tool

-        result = describe_estimator_tool("NotARealEstimator12345")
+        result = describe_component_tool("NotARealComponent12345")

         assert not result["success"]
         assert "error" in result
@@ -488,42 +488,42 @@ def test_repeated_loads_handle_count_stays_constant(self):
         assert len(executor._data_handles) == 3


-class TestSearchEstimatorsLimit:
-    """Tests for the limit parameter validation in search_estimators_tool."""
+class TestQueryRegistryLimit:
+    """Tests for the limit and offset parameter validation in query_registry_tool."""

     def test_limit_zero_returns_error(self):
-        """limit=0 should return an error, not an empty list."""
-        from sktime_mcp.tools.describe_estimator import search_estimators_tool
+        """limit=0 should return an error."""
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = search_estimators_tool("NaiveForecaster", limit=0)
+        result = query_registry_tool(target="estimators", limit=0)

         assert not result["success"]
         assert result["error"] == "limit must be a positive integer."

-    def test_limit_negative_one_returns_error(self):
-        """limit=-1 should return an error, not the last result."""
-        from sktime_mcp.tools.describe_estimator import search_estimators_tool
+    def test_limit_negative_returns_error(self):
+        """Negative limit should return an error."""
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = search_estimators_tool("NaiveForecaster", limit=-1)
+        result = query_registry_tool(target="estimators", limit=-5)

         assert not result["success"]
         assert result["error"] == "limit must be a positive integer."

-    def test_limit_negative_five_returns_error(self):
-        """limit=-5 should return an error, not the last 5 results."""
-        from sktime_mcp.tools.describe_estimator import search_estimators_tool
+    def test_offset_negative_returns_error(self):
+        """Negative offset should return an error."""
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = search_estimators_tool("NaiveForecaster", limit=-5)
+        result = query_registry_tool(target="estimators", offset=-1)

         assert not result["success"]
-        assert result["error"] == "limit must be a positive integer."
+        assert result["error"] == "offset must be a non-negative integer."

     def test_limit_valid_returns_results(self):
         """A positive limit should work correctly and cap results."""
         pytest.importorskip("sktime", reason="sktime not installed in this environment")
-        from sktime_mcp.tools.describe_estimator import search_estimators_tool
+        from sktime_mcp.tools.list_estimators import query_registry_tool

-        result = search_estimators_tool("Forecaster", limit=3)
+        result = query_registry_tool(target="estimators", query="Forecaster", limit=3)

         assert result["success"]
         assert "results" in result
diff --git a/tests/test_discovery_consolidated.py b/tests/test_discovery_consolidated.py
new file mode 100644
index 00000000..c7d28eea
--- /dev/null
+++ b/tests/test_discovery_consolidated.py
@@ -0,0 +1,53 @@
+"""Tests for consolidated discovery tools: query_registry and describe_component."""
+
+from sktime_mcp.tools.describe_component import describe_component_tool
+from sktime_mcp.tools.list_estimators import query_registry_tool
+
+
+def test_query_registry_estimators():
+    """Test query_registry for target='estimators'."""
+    # Test getting estimators
+    res = query_registry_tool(target="estimators", limit=10)
+    assert res["success"]
+    assert "results" in res
+    assert len(res["results"]) > 0
+    assert any(e["task"] == "forecasting" for e in res["results"])
+
+
+def test_query_registry_tags():
+    """Test query_registry for
target='tags'."""
+    # Test getting tags
+    res = query_registry_tool(target="tags")
+    assert res["success"]
+    assert "results" in res
+    assert len(res["results"]) > 0
+    tag_names = {t["tag"] for t in res["results"]}
+    assert "scitype:y" in tag_names or "capability:pred_int" in tag_names
+
+
+def test_query_registry_metrics():
+    """Test query_registry for target='metrics'."""
+    # Test getting performance metrics
+    res = query_registry_tool(target="metrics", limit=20)
+    assert res["success"]
+    assert "results" in res
+    assert len(res["results"]) > 0
+    # Metrics should have task == 'metric'
+    assert any(e["task"] == "metric" for e in res["results"])
+
+
+def test_describe_component_forecaster():
+    """Test describe_component on a forecaster component."""
+    res = describe_component_tool("NaiveForecaster")
+    assert res["success"]
+    assert res["name"] == "NaiveForecaster"
+    assert res["task"] == "forecasting"
+    assert "strategy" in res["parameters"]
+
+
+def test_describe_component_metric():
+    """Test describe_component on a metric component."""
+    res = describe_component_tool("MeanAbsolutePercentageError")
+    assert res["success"]
+    assert res["name"] == "MeanAbsolutePercentageError"
+    assert res["task"] == "metric"
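
Reviewer note: the search-then-paginate behaviour that the tests above exercise can be tried without sktime installed. The sketch below is not the project's code. `Component`, `query_components`, and `SAMPLE` are hypothetical stand-ins, and the module strings are illustrative only. It mirrors the case-insensitive substring match over name, module, and docstring from the `query_registry` hunk, the `offset`/`limit` slice, and the two validation messages the tests assert on. It assumes validation runs before filtering, which this diff does not show.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class Component:
    """Hypothetical stand-in for a registry node (not the sktime-mcp class)."""

    name: str
    module: str
    task: str
    docstring: str = ""


def query_components(
    components: list[Component],
    query: str | None = None,
    limit: int = 50,
    offset: int = 0,
) -> dict[str, Any]:
    """Filter by case-insensitive substring, then return one page of results."""
    # Assumption: pagination arguments are validated up front.
    if offset < 0:
        return {"success": False, "error": "offset must be a non-negative integer."}
    if limit < 1:
        return {"success": False, "error": "limit must be a positive integer."}

    if query:
        q = query.lower()
        components = [
            c
            for c in components
            if q in c.name.lower() or q in c.module.lower() or q in c.docstring.lower()
        ]

    # total counts every match; count is the size of the returned page only.
    total = len(components)
    page = components[offset : offset + limit]
    return {
        "success": True,
        "results": [{"name": c.name, "task": c.task} for c in page],
        "count": len(page),
        "total": total,
        "offset": offset,
        "limit": limit,
        "has_more": (offset + limit) < total,
    }


# Illustrative data only; module paths are not taken from the registry.
SAMPLE = [
    Component("NaiveForecaster", "sktime.forecasting.naive", "forecasting"),
    Component("ARIMA", "sktime.forecasting.arima", "forecasting"),
    Component("MeanAbsolutePercentageError", "sktime.performance_metrics", "metric"),
]

first_page = query_components(SAMPLE, query="forecast", limit=1)
```

Here `first_page["total"]` counts both matching components while `first_page["count"]` reflects the single returned entry, so a caller can keep requesting pages with a larger `offset` until `has_more` is false.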