From e9213f8ed4138efe144652f6f766d4bacf426fec Mon Sep 17 00:00:00 2001 From: FANGAreNotGnu Date: Thu, 7 Nov 2024 08:48:01 +0000 Subject: [PATCH] improve code quality, pending tests --- src/autogluon_assistant/assistant.py | 196 ++++++-- src/autogluon_assistant/llm/llm.py | 212 ++++---- src/autogluon_assistant/predictor.py | 469 ++++++++++++++---- .../prompting/prompt_generator.py | 305 ++++++++---- .../task_inference/task_inference.py | 409 ++++++++++----- 5 files changed, 1106 insertions(+), 485 deletions(-) diff --git a/src/autogluon_assistant/assistant.py b/src/autogluon_assistant/assistant.py index 10d475e..5ea3a05 100644 --- a/src/autogluon_assistant/assistant.py +++ b/src/autogluon_assistant/assistant.py @@ -1,7 +1,8 @@ import logging import os import signal -from typing import Any, Dict, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Type, Union from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf @@ -25,44 +26,82 @@ logger = logging.getLogger(__name__) -class timeout: - def __init__(self, seconds=1, error_message="Transform timed out"): - self.seconds = seconds - self.error_message = error_message +@dataclass +class TimeoutContext: + """Context manager for handling operation timeouts.""" + seconds: int + error_message: str = "Operation timed out" - def handle_timeout(self, signum, frame): + def handle_timeout(self, signum: int, frame: Any) -> None: + """Signal handler for timeout.""" raise TransformTimeoutError(self.error_message) - def __enter__(self): + def __enter__(self) -> 'TimeoutContext': signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) + return self - def __exit__(self, type, value, traceback): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: signal.alarm(0) class TabularPredictionAssistant: - """A TabularPredictionAssistant performs a supervised tabular learning task""" + """ + A TabularPredictionAssistant that performs supervised tabular learning tasks. + + Attributes: + config (DictConfig): Configuration for the assistant + llm (Union[AssistantChatOpenAI, AssistantChatBedrock]): Language model instance + predictor (AutogluonTabularPredictor): AutoGluon predictor instance + feature_transformers_config (Any): Configuration for feature transformers + """ def __init__(self, config: DictConfig) -> None: + """ + Initialize the TabularPredictionAssistant. + + Args: + config (DictConfig): Configuration object containing necessary settings + """ self.config = config - self.llm: Union[AssistantChatOpenAI, AssistantChatBedrock] = LLMFactory.get_chat_model(config.llm) + self.llm = LLMFactory.get_chat_model(config.llm) self.predictor = AutogluonTabularPredictor(config.autogluon) self.feature_transformers_config = config.feature_transformers def describe(self) -> Dict[str, Any]: + """ + Get a description of the assistant's components. + + Returns: + Dict[str, Any]: Description of predictor, config, and LLM + """ return { "predictor": self.predictor.describe(), "config": OmegaConf.to_container(self.config), - "llm": self.llm.describe(), # noqa + "llm": self.llm.describe(), } - def handle_exception(self, stage: str, exception: Exception): + def handle_exception(self, stage: str, exception: Exception) -> None: + """ + Handle exceptions by raising them with additional context. 
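The SIGALRM-based timeout above works only on Unix-like systems and only when entered from the main thread, since that is where Python delivers signals. As a rough standalone sketch of the same pattern, the snippet below uses a hypothetical _Timeout class and a time.sleep call standing in for a slow preprocessor.transform(task):

import signal
import time


class _Timeout:
    """Raise TimeoutError if the body of the with-block runs too long."""

    def __init__(self, seconds: int, message: str = "Operation timed out"):
        self.seconds = seconds
        self.message = message

    def _handler(self, signum, frame):
        raise TimeoutError(self.message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._handler)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # cancel any pending alarm whether or not the body raised


try:
    with _Timeout(seconds=1, message="transform took too long"):
        time.sleep(2)  # stand-in for a slow transform
except TimeoutError as exc:
    print(f"caught: {exc}")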
+ + Args: + stage (str): The processing stage where the exception occurred + exception (Exception): The original exception + + Raises: + Exception: Enhanced exception with stage information + """ raise Exception(str(exception), stage) - def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask: - logger.info("Task understanding starts...") - task_inference_preprocessors = [ + def _get_task_inference_preprocessors(self) -> List[Type]: + """ + Get the list of task inference preprocessors based on configuration. + + Returns: + List[Type]: List of preprocessor classes + """ + preprocessors = [ DescriptionFileNameInference, DataFileNameInference, LabelColumnInference, @@ -70,68 +109,121 @@ def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask: ] if self.config.detect_and_drop_id_column: - task_inference_preprocessors += [ + preprocessors.extend([ OutputIDColumnInference, TrainIDColumnInference, TestIDColumnInference, - ] + ]) + if self.config.infer_eval_metric: - task_inference_preprocessors += [EvalMetricInference] - for preprocessor_class in task_inference_preprocessors: + preprocessors.append(EvalMetricInference) + + return preprocessors + + def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask: + """ + Perform task inference using configured preprocessors. + + Args: + task (TabularPredictionTask): The initial task object + + Returns: + TabularPredictionTask: The processed task object + """ + logger.info("Task understanding starts...") + + for preprocessor_class in self._get_task_inference_preprocessors(): preprocessor = preprocessor_class(llm=self.llm) try: - with timeout( + with TimeoutContext( seconds=self.config.task_preprocessors_timeout, - error_message=f"Task inference preprocessing timed out: {preprocessor_class}", + error_message=f"Task inference preprocessing timed out: {preprocessor_class.__name__}" ): task = preprocessor.transform(task) except Exception as e: - self.handle_exception(f"Task inference preprocessing: {preprocessor_class}", e) + self.handle_exception(f"Task inference preprocessing: {preprocessor_class.__name__}", e) - bold_start = "\033[1m" - bold_end = "\033[0m" - - logger.info(f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}") - logger.info(f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}") + self._log_token_usage() logger.info("Task understanding complete!") return task + def _log_token_usage(self) -> None: + """Log the token usage statistics.""" + bold_format = lambda text: f"\033[1m{text}\033[0m" + logger.info(f"{bold_format('Total number of prompt tokens:')} {self.llm.input_}") + logger.info(f"{bold_format('Total number of completion tokens:')} {self.llm.output_}") + def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask: - # instantiate and run task preprocessors, which infer the problem type, important filenames - # and columns as well as the feature extractors + """ + Preprocess the task using inference and feature transformers. 
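inference_task above applies each preprocessor class in sequence, so every step receives the task produced by the previous one. A toy sketch of that chain pattern, with ToyTask and ToyPreprocessor as illustrative stand-ins for TabularPredictionTask and the real inference classes:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Type


@dataclass
class ToyTask:
    metadata: Dict[str, Any] = field(default_factory=dict)


class ToyPreprocessor:
    def __init__(self, llm: Any) -> None:
        self.llm = llm

    def transform(self, task: ToyTask) -> ToyTask:
        task.metadata[type(self).__name__] = "done"  # record that this step ran
        return task


def run_chain(task: ToyTask, llm: Any, steps: List[Type[ToyPreprocessor]]) -> ToyTask:
    for step_cls in steps:
        task = step_cls(llm=llm).transform(task)  # each step consumes the prior step's output
    return task


print(run_chain(ToyTask(), llm=None, steps=[ToyPreprocessor]).metadata)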
+ + Args: + task (TabularPredictionTask): The task to preprocess + + Returns: + TabularPredictionTask: The preprocessed task + """ task = self.inference_task(task) - if self.feature_transformers_config: - logger.info("Automatic feature generation starts...") - if "OPENAI_API_KEY" not in os.environ: - logger.info("No OpenAI API keys found, therefore, skip CAAFE") - fe_transformers = [ - instantiate(ft_config) - for ft_config in self.feature_transformers_config - if ft_config["_target_"] != "autogluon_assistant.transformer.CAAFETransformer" - ] - else: - fe_transformers = [instantiate(ft_config) for ft_config in self.feature_transformers_config] - for fe_transformer in fe_transformers: - try: - with timeout( - seconds=self.config.task_preprocessors_timeout, - error_message=f"Task preprocessing timed out: {fe_transformer.name}", - ): - task = fe_transformer.fit_transform(task) - except Exception as e: - self.handle_exception(f"Task preprocessing: {fe_transformer.name}", e) - logger.info("Automatic feature generation complete!") - else: + + if not self.feature_transformers_config: logger.info("Automatic feature generation is disabled.") + return task + + logger.info("Automatic feature generation starts...") + fe_transformers = self._get_feature_transformers() + + for fe_transformer in fe_transformers: + try: + with TimeoutContext( + seconds=self.config.task_preprocessors_timeout, + error_message=f"Task preprocessing timed out: {fe_transformer.name}" + ): + task = fe_transformer.fit_transform(task) + except Exception as e: + self.handle_exception(f"Task preprocessing: {fe_transformer.name}", e) + + logger.info("Automatic feature generation complete!") return task - def fit_predictor(self, task: TabularPredictionTask): + def _get_feature_transformers(self) -> List[Any]: + """ + Get the list of feature transformers based on configuration and environment. + + Returns: + List[Any]: List of instantiated feature transformers + """ + if "OPENAI_API_KEY" not in os.environ: + logger.info("No OpenAI API keys found, therefore, skip CAAFE") + return [ + instantiate(ft_config) + for ft_config in self.feature_transformers_config + if ft_config["_target_"] != "autogluon_assistant.transformer.CAAFETransformer" + ] + + return [instantiate(ft_config) for ft_config in self.feature_transformers_config] + + def fit_predictor(self, task: TabularPredictionTask) -> None: + """ + Fit the predictor on the given task. + + Args: + task (TabularPredictionTask): The task to fit the predictor on + """ try: self.predictor.fit(task) except Exception as e: self.handle_exception("Predictor Fit", e) def predict(self, task: TabularPredictionTask) -> Any: + """ + Make predictions using the fitted predictor. + + Args: + task (TabularPredictionTask): The task to make predictions for + + Returns: + Any: Predictions from the predictor + """ try: return self.predictor.predict(task) except Exception as e: diff --git a/src/autogluon_assistant/llm/llm.py b/src/autogluon_assistant/llm/llm.py index c6f311b..0d9a759 100644 --- a/src/autogluon_assistant/llm/llm.py +++ b/src/autogluon_assistant/llm/llm.py @@ -1,7 +1,7 @@ import logging import os import pprint -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import boto3 import botocore @@ -16,151 +16,186 @@ logger = logging.getLogger(__name__) -class AssistantChatOpenAI(ChatOpenAI, BaseModel): - """ - AssistantChatOpenAI is a subclass of ChatOpenAI that traces the input and output of the model. 
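_get_feature_transformers above filters out the CAAFE transformer config when no OpenAI key is present and instantiates the rest through Hydra. A rough sketch of that gating, with plain dicts standing in for the Hydra configs (the second _target_ name is a hypothetical placeholder) and the instantiate() call left out:

import os

CAAFE_TARGET = "autogluon_assistant.transformer.CAAFETransformer"

ft_configs = [
    {"_target_": CAAFE_TARGET},
    {"_target_": "autogluon_assistant.transformer.SomeOtherTransformer"},  # hypothetical
]

if "OPENAI_API_KEY" not in os.environ:
    # CAAFE needs an OpenAI key, so its config is dropped before instantiation.
    ft_configs = [cfg for cfg in ft_configs if cfg["_target_"] != CAAFE_TARGET]

print([cfg["_target_"] for cfg in ft_configs])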
- """ +class ChatModelMixin(BaseModel): + """Base mixin class for chat models with common functionality.""" history_: List[Dict[str, Any]] = Field(default_factory=list) - input_: int = Field(default=0) - output_: int = Field(default=0) + input_tokens: int = Field(default=0, alias="input_") + output_tokens: int = Field(default=0, alias="output_") + + def update_token_usage(self, response: AIMessage) -> None: + """Update token usage based on response metadata.""" + if response.usage_metadata: + self.input_tokens += response.usage_metadata.get("input_tokens", 0) + self.output_tokens += response.usage_metadata.get("output_tokens", 0) + + def append_to_history(self, input_messages: List[BaseMessage], response: AIMessage) -> None: + """Append interaction to history.""" + self.history_.append({ + "input": [{"type": msg.type, "content": msg.content} for msg in input_messages], + "output": pprint.pformat(dict(response)), + "prompt_tokens": self.input_tokens, + "completion_tokens": self.output_tokens, + }) + + +class AssistantChatOpenAI(ChatOpenAI, ChatModelMixin): + """OpenAI chat model with input/output tracing capabilities.""" def describe(self) -> Dict[str, Any]: + """Return description of the model configuration and usage.""" return { "model": self.model_name, "proxy": self.openai_proxy, "history": self.history_, - "prompt_tokens": self.input_, - "completion_tokens": self.output_, + "prompt_tokens": self.input_tokens, + "completion_tokens": self.output_tokens, } - @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)) - def invoke(self, *args, **kwargs): - input_: List[BaseMessage] = args[0] - response = super().invoke(*args, **kwargs) - - # Update token usage - if isinstance(response, AIMessage) and response.usage_metadata: - self.input_ += response.usage_metadata.get("input_tokens", 0) - self.output_ += response.usage_metadata.get("output_tokens", 0) - - self.history_.append( - { - "input": [{"type": msg.type, "content": msg.content} for msg in input_], - "output": pprint.pformat(dict(response)), - "prompt_tokens": self.input_, - "completion_tokens": self.output_, - } - ) + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10) + ) + def invoke(self, input_messages: List[BaseMessage], **kwargs) -> AIMessage: + """Invoke the model with retry logic and tracking.""" + response = super().invoke(input_messages, **kwargs) + + if isinstance(response, AIMessage): + self.update_token_usage(response) + self.append_to_history(input_messages, response) + return response -class AssistantChatBedrock(ChatBedrock, BaseModel): - """ - AssistantChatBedrock is a subclass of ChatBedrock that traces the input and output of the model. 
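ChatModelMixin.update_token_usage above reads counts from the usage_metadata that recent langchain-core releases attach to AIMessage. A minimal sketch of that bookkeeping, using a hand-built message instead of a live model call and assuming langchain-core 0.2 or newer:

from langchain_core.messages import AIMessage

totals = {"input_tokens": 0, "output_tokens": 0}

response = AIMessage(
    content="hello",
    usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)

if response.usage_metadata:
    totals["input_tokens"] += response.usage_metadata.get("input_tokens", 0)
    totals["output_tokens"] += response.usage_metadata.get("output_tokens", 0)

print(totals)  # {'input_tokens': 12, 'output_tokens': 3}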
- """ - - history_: List[Dict[str, Any]] = Field(default_factory=list) - input_: int = Field(default=0) - output_: int = Field(default=0) +class AssistantChatBedrock(ChatBedrock, ChatModelMixin): + """Bedrock chat model with input/output tracing capabilities.""" def describe(self) -> Dict[str, Any]: + """Return description of the model configuration and usage.""" return { "model": self.model_id, "history": self.history_, - "prompt_tokens": self.input_, - "completion_tokens": self.output_, + "prompt_tokens": self.input_tokens, + "completion_tokens": self.output_tokens, } - @retry(stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10)) - def invoke(self, *args, **kwargs): - input_: List[BaseMessage] = args[0] + @retry( + stop=stop_after_attempt(50), + wait=wait_exponential(multiplier=1, min=4, max=10) + ) + def invoke(self, input_messages: List[BaseMessage], **kwargs) -> AIMessage: + """Invoke the model with retry logic and tracking.""" try: - response = super().invoke(*args, **kwargs) + response = super().invoke(input_messages, **kwargs) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "ThrottlingException": raise e - else: - raise e + raise - # Update token usage - if isinstance(response, AIMessage) and response.usage_metadata: - self.input_ += response.usage_metadata.get("input_tokens", 0) - self.output_ += response.usage_metadata.get("output_tokens", 0) - - self.history_.append( - { - "input": [{"type": msg.type, "content": msg.content} for msg in input_], - "output": pprint.pformat(dict(response)), - "prompt_tokens": self.input_, - "completion_tokens": self.output_, - } - ) + if isinstance(response, AIMessage): + self.update_token_usage(response) + self.append_to_history(input_messages, response) + return response class LLMFactory: + """Factory class for creating and managing LLM instances.""" + + VALID_PROVIDERS = ["openai", "bedrock"] + @staticmethod def get_openai_models() -> List[str]: + """Fetch available OpenAI models.""" try: client = OpenAI() models = client.models.list() - return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))] + return [ + model.id + for model in models + if model.id.startswith(("gpt-3.5", "gpt-4")) + ] except Exception as e: - print(f"Error fetching OpenAI models: {e}") + logger.error(f"Error fetching OpenAI models: {e}") return [] @staticmethod def get_bedrock_models() -> List[str]: + """Fetch available Bedrock models.""" try: bedrock = boto3.client("bedrock", region_name="us-west-2") response = bedrock.list_foundation_models() - return [ + models = [ model["modelId"] for model in response["modelSummaries"] if model["modelId"].startswith("anthropic.claude") ] + if not models: + raise ValueError("No valid Bedrock models found") + return models except Exception as e: - print(f"Error fetching Bedrock models: {e}") + logger.error(f"Error fetching Bedrock models: {e}") return [] @classmethod - def get_valid_models(cls, provider): + def get_valid_models(cls, provider: str) -> List[str]: + """Get valid models for a given provider.""" + if provider not in cls.VALID_PROVIDERS: + raise ValueError(f"Invalid LLM provider: {provider}") + if provider == "openai": return cls.get_openai_models() - elif provider == "bedrock": - model_names = cls.get_bedrock_models() - assert len(model_names), "Check your bedrock keys" - return model_names - else: - raise ValueError(f"Invalid LLM provider: {provider}") + return cls.get_bedrock_models() @classmethod - def get_valid_providers(cls): - return 
["openai", "bedrock"] + def get_chat_model( + cls, + config: DictConfig + ) -> Union[AssistantChatOpenAI, AssistantChatBedrock]: + """Create a chat model instance based on configuration.""" + provider = config.provider + model = config.model + + if provider not in cls.VALID_PROVIDERS: + raise ValueError( + f"Invalid provider: {provider}. Must be one of {cls.VALID_PROVIDERS}" + ) + + valid_models = cls.get_valid_models(provider) + if not valid_models: + raise ValueError(f"No valid models found for provider: {provider}") + + if model not in valid_models: + raise ValueError( + f"Invalid model: {model}. Must be one of {valid_models}" + ) - @staticmethod - def _get_openai_chat_model(config: DictConfig) -> AssistantChatOpenAI: - if config.api_key_location in os.environ: - api_key = os.environ[config.api_key_location] - else: - raise Exception("OpenAI API env variable not set") + if provider == "openai": + return cls._create_openai_model(config) + return cls._create_bedrock_model(config) - logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.") + @staticmethod + def _create_openai_model(config: DictConfig) -> AssistantChatOpenAI: + """Create an OpenAI chat model instance.""" + if config.api_key_location not in os.environ: + raise ValueError(f"OpenAI API key not found in environment: {config.api_key_location}") + logger.info(f"Using OpenAI model: {config.model}") + return AssistantChatOpenAI( model_name=config.model, temperature=config.temperature, max_tokens=config.max_tokens, verbose=config.verbose, - openai_api_key=api_key, + openai_api_key=os.environ[config.api_key_location], openai_api_base=config.proxy_url, ) @staticmethod - def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock: - logger.info(f"AGA is using model {config.model} from Bedrock to assist you with the task.") - + def _create_bedrock_model(config: DictConfig) -> AssistantChatBedrock: + """Create a Bedrock chat model instance.""" + logger.info(f"Using Bedrock model: {config.model}") + return AssistantChatBedrock( model_id=config.model, model_kwargs={ @@ -170,20 +205,3 @@ def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock: region_name="us-west-2", verbose=config.verbose, ) - - @classmethod - def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, AssistantChatBedrock]: - valid_providers = cls.get_valid_providers() - assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}" - - valid_models = cls.get_valid_models(config.provider) - assert ( - config.model in valid_models - ), f"{config.model} is not a valid model in: {valid_models} for provider {config.provider}" - - if config.provider == "openai": - return LLMFactory._get_openai_chat_model(config) - elif config.provider == "bedrock": - return LLMFactory._get_bedrock_chat_model(config) - else: - raise ValueError(f"Invalid LLM provider: {config.provider}") diff --git a/src/autogluon_assistant/predictor.py b/src/autogluon_assistant/predictor.py index ff97605..1fdef70 100644 --- a/src/autogluon_assistant/predictor.py +++ b/src/autogluon_assistant/predictor.py @@ -1,122 +1,387 @@ -"""Predictors solve tabular prediction tasks""" +"""Module for handling tabular machine learning prediction tasks. 
-import logging -from collections import defaultdict -from typing import Any, Dict +This module provides the TabularPredictionTask class which encapsulates data and metadata +for tabular machine learning tasks, including datasets and their associated metadata. +""" -import numpy as np -from autogluon.common.features.feature_metadata import FeatureMetadata -from autogluon.core.metrics import make_scorer -from autogluon.tabular import TabularDataset, TabularPredictor -from sklearn.metrics import mean_squared_log_error +import os +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, TypeVar, cast -from .constants import BINARY, CLASSIFICATION_PROBA_EVAL_METRIC, MULTICLASS -from .task import TabularPredictionTask +import joblib +import pandas as pd +import s3fs +from autogluon.tabular import TabularDataset -logger = logging.getLogger(__name__) +# Type aliases for better readability +DatasetType = Union[Path, pd.DataFrame, TabularDataset] +PathLike = Union[str, Path] +# Constants +TRAIN = "train" +TEST = "test" +OUTPUT = "output" -def rmsle_func(y_true, y_pred, **kwargs): - return np.sqrt(mean_squared_log_error(y_true, y_pred, **kwargs)) +T = TypeVar('T', bound='TabularPredictionTask') +class TabularPredictionTask: + """A class representing a tabular machine learning prediction task. + + This class contains data and metadata for a tabular machine learning task, + including datasets, problem type, and other relevant metadata. + + Attributes: + metadata (Dict[str, Any]): Task metadata including name, description, and configuration + filepaths (List[Path]): List of paths to task-related files + cache_data (bool): Whether to cache loaded data in memory + dataset_mapping (Dict[str, Union[Path, pd.DataFrame, TabularDataset]]): Mapping of dataset types to their data + """ -root_mean_square_logarithmic_error = make_scorer( - "root_mean_square_logarithmic_error", - rmsle_func, - optimum=1, - greater_is_better=False, -) + def __init__( + self, + filepaths: List[Path], + metadata: Dict[str, Any], + name: str = "", + description: str = "", + cache_data: bool = True, + ) -> None: + """Initialize a new TabularPredictionTask. + Args: + filepaths: List of paths to task-related files + metadata: Dictionary containing task metadata + name: Name of the task + description: Description of the task + cache_data: Whether to cache loaded data in memory + """ + self.metadata = { + "name": name, + "description": description, + "label_column": None, + "problem_type": None, + "eval_metric": None, + "test_id_column": None, + } + self.metadata.update(metadata) + + self.filepaths = filepaths + self.cache_data = cache_data + self.dataset_mapping: Dict[str, Optional[DatasetType]] = { + TRAIN: None, + TEST: None, + OUTPUT: None, + } + + def __repr__(self) -> str: + """Return a string representation of the task.""" + return ( + f"TabularPredictionTask(name={self.metadata['name']}, " + f"description={self.metadata['description'][:100]}, " + f"{len(self.dataset_mapping)} datasets)" + ) + + @staticmethod + def read_task_file(task_path: Path, filename_pattern: str, default_filename: str = "description.txt") -> str: + """Read contents of a task file, searching recursively in the task path. 
+ + Args: + task_path: Base path to search for the file + filename_pattern: Pattern to match the filename + default_filename: Fallback filename if pattern isn't found + + Returns: + Contents of the found file as a string + """ + try: + matching_paths = sorted( + list(task_path.glob(filename_pattern)), + key=lambda x: len(x.parents), # top level files take precedence + ) + if not matching_paths: + return Path(task_path / default_filename).read_text() + return matching_paths[0].read_text() + except (FileNotFoundError, IndexError): + return "" -class Predictor: - def fit(self, task: TabularPredictionTask) -> "Predictor": - return self + @staticmethod + def save_artifacts( + full_save_path: PathLike, + predictor: Any, + train_data: pd.DataFrame, + test_data: pd.DataFrame, + sample_submission_data: pd.DataFrame, + ) -> None: + """Save model artifacts either locally or to S3. - def predict(self, task: TabularPredictionTask) -> Any: - raise NotImplementedError + Args: + full_save_path: Path where artifacts should be saved + predictor: AutoGluon TabularPredictor instance + train_data: Training data + test_data: Test data + sample_submission_data: Sample submission data + """ + artifacts = { + "trained_model": predictor, + "train_data": train_data, + "test_data": test_data, + "out_data": sample_submission_data, + } + + ag_model_dir = predictor.predictor.path + full_save_path_pkl_file = f"{full_save_path}/artifacts.pkl" + + if str(full_save_path).startswith("s3://"): + fs = s3fs.S3FileSystem() + with fs.open(full_save_path_pkl_file, "wb") as f: + joblib.dump(artifacts, f) + + s3_model_dir = f"{full_save_path}/{os.path.dirname(ag_model_dir)}/{os.path.basename(ag_model_dir)}" + fs.put(ag_model_dir, s3_model_dir, recursive=True) + else: + os.makedirs(str(full_save_path), exist_ok=True) + with open(full_save_path_pkl_file, "wb") as f: + joblib.dump(artifacts, f) - def fit_predict(self, task: TabularPredictionTask) -> Any: - return self.fit(task).predict(task) + local_model_dir = os.path.join(str(full_save_path), ag_model_dir) + shutil.copytree(ag_model_dir, local_model_dir, dirs_exist_ok=True) + @classmethod + def from_path(cls: type[T], task_root_dir: Path, name: Optional[str] = None) -> T: + """Create a TabularPredictionTask instance from a directory path. 
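read_task_file above globs recursively and then sorts matches by len(path.parents), so a file closer to the task root wins over a deeply nested one with the same name. A small sketch of that ordering rule (the directory name and pattern are illustrative):

from pathlib import Path
from typing import Optional


def find_shallowest(task_path: Path, pattern: str) -> Optional[Path]:
    # Fewer parents means the match sits closer to the task root.
    matches = sorted(task_path.glob(pattern), key=lambda p: len(p.parents))
    return matches[0] if matches else None


# With a layout such as my_task/description.txt and my_task/data/description.txt,
# the root-level file is returned; with no matches the result is None.
print(find_shallowest(Path("my_task"), "**/description.txt"))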
-class AutogluonTabularPredictor(Predictor): + Args: + task_root_dir: Root directory containing task files + name: Optional name for the task - def __init__(self, config: Any): - self.config = config - self.metadata: Dict[str, Any] = defaultdict(dict) - self.tabular_predictor: TabularPredictor = None + Returns: + A new TabularPredictionTask instance + """ + task_data_filenames = [] + for root, _, files in os.walk(task_root_dir): + for file in files: + relative_path = os.path.relpath(os.path.join(root, file), task_root_dir) + task_data_filenames.append(relative_path) - def save_dataset_details(self, task: TabularPredictionTask) -> None: - for key, data in ( - ("train", task.train_data), - ("test", task.test_data), - ): - self.metadata["dataset_summary"][key] = data.describe().to_dict() - self.metadata["feature_metadata_raw"][key] = FeatureMetadata.from_df(data).to_dict() - self.metadata["feature_missing_values"][key] = (data.isna().sum() / len(data)).to_dict() + return cls( + filepaths=[task_root_dir / fn for fn in task_data_filenames], + metadata={"name": name or task_root_dir.name}, + ) def describe(self) -> Dict[str, Any]: - return dict(self.metadata) - - def fit(self, task: TabularPredictionTask) -> "AutogluonTabularPredictor": - """Trains an AutoGluon TabularPredictor with parsed arguments. Saves trained predictor to - `self.predictor`. - - Raises - ------ - Exception - TabularPredictor fit failures - """ - eval_metric = task.eval_metric - if eval_metric == "root_mean_squared_logarithmic_error": - eval_metric = root_mean_square_logarithmic_error - - predictor_init_kwargs = { - "learner_kwargs": {"ignored_columns": task.columns_in_train_but_not_test}, - "label": task.label_column, - "eval_metric": eval_metric, - **self.config.predictor_init_kwargs, - } - predictor_fit_kwargs = self.config.predictor_fit_kwargs.copy() - - logger.info("Fitting AutoGluon TabularPredictor") - logger.info(f"predictor_init_kwargs: {predictor_init_kwargs}") - logger.info(f"predictor_fit_kwargs: {predictor_fit_kwargs}") - - if predictor_fit_kwargs.get("dynamic_stacking", False): - # Use config value for num_stack_levels - # Default 1 if dynamic_stacking is True - if predictor_fit_kwargs["num_stack_levels"] == 0: - predictor_fit_kwargs["num_stack_levels"] = 1 - logger.info( - f"Dynamic stacking is enabled; setting num_stack_levels={predictor_fit_kwargs['num_stack_levels']}" - ) + """Generate a description of the task including metadata and data statistics. - self.metadata |= { - "predictor_init_kwargs": predictor_init_kwargs, - "predictor_fit_kwargs": predictor_fit_kwargs, + Returns: + Dictionary containing task description and statistics + """ + description = { + "name": self.metadata["name"], + "description": self.metadata["description"], + "metadata": self.metadata, + "train_data": self.train_data.describe().to_dict(), + "test_data": self.test_data.describe().to_dict(), } - self.save_dataset_details(task) - self.predictor = TabularPredictor(**predictor_init_kwargs).fit(task.train_data, **predictor_fit_kwargs) - - self.metadata["leaderboard"] = self.predictor.leaderboard().to_dict() - return self - - def predict(self, task: TabularPredictionTask) -> TabularDataset: - """Calls `TabularPredictor.predict` or `TabularPredictor.predict_proba` on `self.transformed_test_data`. - Saves predictions to `self.predictions`. 
- - Raises - ------ - Exception - `TabularPredictor.predict` fails - """ - if task.eval_metric in CLASSIFICATION_PROBA_EVAL_METRIC and self.predictor.problem_type in [ - BINARY, - MULTICLASS, - ]: - return self.predictor.predict_proba( - task.test_data, as_multiclass=(self.predictor.problem_type == MULTICLASS) - ) - else: - return self.predictor.predict(task.test_data) + + if self.sample_submission_data is not None: + description["sample_submission_data"] = self.sample_submission_data.describe().to_dict() + + return description + + def get_filenames(self) -> List[str]: + """Get all filenames associated with the task. + + Returns: + List of filenames + """ + return [f.name for f in self.filepaths] + + def _set_task_files(self, dataset_name_mapping: Dict[str, Optional[Union[str, DatasetType]]]) -> None: + """Set the task files based on the provided mapping. + + Args: + dataset_name_mapping: Mapping of dataset names to their sources + """ + for key, value in dataset_name_mapping.items(): + if value is None: + self.dataset_mapping[key] = None + continue + + if isinstance(value, (pd.DataFrame, TabularDataset)): + self.dataset_mapping[key] = value + elif isinstance(value, (str, Path)): + filepath = ( + value if isinstance(value, Path) + else next( + (path for path in self.filepaths if path.name == value), + self.filepaths[0].parent / value + ) + ) + + if not filepath.is_file(): + raise ValueError(f"File {value} not found in task {self.metadata['name']}") + + if filepath.suffix in [".xlsx", ".xls"]: + self.dataset_mapping[key] = ( + pd.read_excel(filepath, engine="calamine") + if self.cache_data else filepath + ) + else: + self.dataset_mapping[key] = ( + TabularDataset(str(filepath)) + if self.cache_data else filepath + ) + else: + raise TypeError(f"Unsupported type for dataset_mapping: {type(value)}") + + @property + def train_data(self) -> TabularDataset: + """Get the training dataset. + + Returns: + Training dataset as TabularDataset + """ + return self.load_task_data(TRAIN) + + @train_data.setter + def train_data(self, data: Union[str, Path, pd.DataFrame, TabularDataset]) -> None: + """Set the training dataset. + + Args: + data: Training data to set + """ + self._set_task_files({TRAIN: data}) + + @property + def test_data(self) -> TabularDataset: + """Get the test dataset. + + Returns: + Test dataset as TabularDataset + """ + return self.load_task_data(TEST) + + @test_data.setter + def test_data(self, data: Union[str, Path, pd.DataFrame, TabularDataset]) -> None: + """Set the test dataset. + + Args: + data: Test data to set + """ + self._set_task_files({TEST: data}) + + @property + def sample_submission_data(self) -> Optional[TabularDataset]: + """Get the sample submission dataset. + + Returns: + Sample submission dataset as TabularDataset if available + """ + return self.load_task_data(OUTPUT) + + @sample_submission_data.setter + def sample_submission_data(self, data: Union[str, Path, pd.DataFrame, TabularDataset]) -> None: + """Set the sample submission dataset. + + Args: + data: Sample submission data to set + + Raises: + ValueError: If output data is already set + """ + if self.sample_submission_data is not None: + raise ValueError("Output data already set for task") + self._set_task_files({OUTPUT: data}) + + @property + def output_columns(self) -> Optional[List[str]]: + """Get the output dataset columns. 
+ + Returns: + List of column names or None if not available + """ + if self.sample_submission_data is None: + return [self.label_column] if self.label_column else None + return self.sample_submission_data.columns.tolist() + + @property + def label_column(self) -> Optional[str]: + """Get the label column name. + + Returns: + Name of the label column or None if not set + """ + return self.metadata.get("label_column") or self._infer_label_column_from_sample_submission_data() + + @label_column.setter + def label_column(self, label_column: str) -> None: + """Set the label column name. + + Args: + label_column: Name of the label column + """ + self.metadata["label_column"] = label_column + + @property + def columns_in_train_but_not_test(self) -> List[str]: + """Get columns that exist in training data but not in test data. + + Returns: + List of column names + """ + return list(set(self.train_data.columns) - set(self.test_data.columns)) + + def _infer_label_column_from_sample_submission_data(self) -> Optional[str]: + """Infer the label column from sample submission data. + + Returns: + Inferred label column name or None if cannot be inferred + + Raises: + ValueError: If unable to infer the label column + """ + if self.output_columns is None: + return None + + relevant_output_cols = self.output_columns[1:] # Ignore first column (assumed to be ID) + existing_output_cols = [col for col in relevant_output_cols if col in self.train_data.columns] + + if len(existing_output_cols) == 1: + return existing_output_cols[0] + + output_set = {col.lower() for col in relevant_output_cols} + for col in self.train_data.columns: + unique_values = {str(val).lower() for val in self.train_data[col].unique() if pd.notna(val)} + if output_set == unique_values or output_set.issubset(unique_values): + return col + + raise ValueError("Unable to infer the label column. Please specify it manually.") + + def load_task_data(self, dataset_key: str) -> Optional[TabularDataset]: + """Load task data for a specific dataset type. 
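_infer_label_column_from_sample_submission_data above drops the submission's first column (assumed to be the ID) and then looks for a remaining column that also appears in the training data. A toy sketch of that heuristic with made-up frames:

import pandas as pd

train = pd.DataFrame({"id": [1, 2], "text": ["a", "b"], "target": [0, 1]})
submission = pd.DataFrame({"id": [1, 2], "target": [0.5, 0.5]})

candidate_cols = submission.columns[1:]  # skip the assumed ID column
matches = [col for col in candidate_cols if col in train.columns]
label_column = matches[0] if len(matches) == 1 else None
print(label_column)  # target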
+ + Args: + dataset_key: Key identifying the dataset to load + + Returns: + Loaded dataset as TabularDataset or None if not available + + Raises: + ValueError: If dataset type is not found + TypeError: If file format is not supported + """ + if dataset_key not in self.dataset_mapping: + raise ValueError(f"Dataset type {dataset_key} not found for task {self.metadata['name']}") + + dataset = self.dataset_mapping[dataset_key] + if dataset is None: + return None + + if isinstance(dataset, pd.DataFrame): + return TabularDataset(dataset) + if isinstance(dataset, TabularDataset): + return dataset + + if dataset.suffix == ".json": + raise TypeError(f"File {dataset.name} has unsupported type: json") + + return TabularDataset(str(dataset)) diff --git a/src/autogluon_assistant/prompting/prompt_generator.py b/src/autogluon_assistant/prompting/prompt_generator.py index 9d81a34..411ef0d 100644 --- a/src/autogluon_assistant/prompting/prompt_generator.py +++ b/src/autogluon_assistant/prompting/prompt_generator.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Union +from typing import List, Optional, Union, Dict from langchain.output_parsers import ResponseSchema, StructuredOutputParser from langchain.prompts.chat import ChatPromptTemplate from langchain_core.messages import HumanMessage, SystemMessage +from pandas import DataFrame from ..constants import METRICS_DESCRIPTION, NO_FILE_IDENTIFIED, NO_ID_COLUMN_IDENTIFIED, PROBLEM_TYPES from ..utils import is_text_file, load_pd_quietly @@ -12,210 +13,310 @@ class PromptGenerator(ABC): - fields = None - - def __init__(self, data_description: str = ""): + """ + Abstract base class for generating prompts in data science tasks. + + This class provides a framework for creating structured prompts that can be used + to extract specific information about data science tasks. + + Attributes: + fields (List[str]): List of fields to be extracted from the prompt response + data_description (str): Description of the data science task + parser (StructuredOutputParser): Parser for structured output + """ + + fields: List[str] = None + + def __init__(self, data_description: str = "") -> None: + """ + Initialize the prompt generator. + + Args: + data_description: Description of the data science task + """ self.data_description = data_description - self.parser = self.create_parser() + self.parser = self._create_parser() @property - def system_prompt(self): + def system_prompt(self) -> str: + """Return the system prompt for the assistant.""" return "You are an expert assistant that parses information about data science tasks, such as data science competitions." @property - def basic_intro_prompt(self): + def basic_intro_prompt(self) -> str: + """Return the basic introduction prompt.""" return "The following sections contain descriptive information about a data science task:" @property - def data_description_prompt(self): + def data_description_prompt(self) -> str: + """Return the data description prompt.""" return f"# Data Description\n{self.data_description}" @abstractmethod def generate_prompt(self) -> str: + """Generate the complete prompt string.""" pass def get_field_parsing_prompt(self) -> str: + """Generate the prompt for field parsing instructions.""" return ( f"Based on the above information, provide the correct values for the following fields strictly " f"in valid JSON format: {', '.join(self.fields)}.\n\n" "Important:\n" "1. Return only valid JSON. No extra explanations, text, or comments.\n" "2. 
Ensure that the output can be parsed by a JSON parser directly.\n" - "3. Do not include any non-JSON text or formatting outside the JSON object." - '4. An example is \{"": ""\}' + "3. Do not include any non-JSON text or formatting outside the JSON object.\n" + '4. An example is {"": ""}' ) - def generate_chat_prompt(self): - chat_prompt = ChatPromptTemplate.from_messages( - [ - SystemMessage(content=self.system_prompt), - HumanMessage(content=self.generate_prompt()), - ] - ) - return chat_prompt + def generate_chat_prompt(self) -> ChatPromptTemplate: + """Generate a chat prompt template.""" + return ChatPromptTemplate.from_messages([ + SystemMessage(content=self.system_prompt), + HumanMessage(content=self.generate_prompt()), + ]) - def create_parser(self): + def _create_parser(self) -> StructuredOutputParser: + """Create a structured output parser for the fields.""" response_schemas = [ - ResponseSchema(name=field, description=f"The {field} for the task") for field in self.fields + ResponseSchema(name=field, description=f"The {field} for the task") + for field in self.fields ] return StructuredOutputParser.from_response_schemas(response_schemas) class DescriptionFileNamePromptGenerator(PromptGenerator): + """Generator for prompts to identify description and evaluation files.""" + fields = ["data_description_file", "evaluation_description_file"] - def __init__(self, filenames: list): + def __init__(self, filenames: List[str]) -> None: + """ + Initialize the description file name prompt generator. + + Args: + filenames: List of available filenames + """ super().__init__() self.filenames = filenames - def read_file_safely(self, filename: Path) -> Union[str, None]: + def read_file_safely(self, filename: Path) -> Optional[str]: + """ + Safely read a file's contents. + + Args: + filename: Path to the file + + Returns: + The file contents or None if the file cannot be read + """ try: return filename.read_text() except UnicodeDecodeError: return None def generate_prompt(self) -> str: - file_content_prompts = "# Available Files And Content in The File\n\n" + """Generate a prompt for identifying description files.""" + file_content_prompts = [] + file_content_prompts.append("# Available Files And Content in The File\n") + for filename in map(Path, self.filenames): if is_text_file(filename): content = self.read_file_safely(filename) if content is not None: - truncated_contents = content[:100].strip() - if len(content) > 100: - truncated_contents += "..." - file_content_prompts += f"File:\n\n{filename} Truncated Content:\n{truncated_contents}\n\n" - file_content_prompts += f"Please return the full path of the file to describe the problem settings, and response with the value {NO_FILE_IDENTIFIED} if there's no such file." - - return "\n\n".join( - [ - self.basic_intro_prompt, - file_content_prompts, - self.get_field_parsing_prompt(), - ] + truncated_contents = f"{content[:100].strip()}..." + file_content_prompts.append( + f"File:\n\n{filename} Truncated Content:\n{truncated_contents}\n" + ) + + file_content_prompts.append( + f"Please return the full path of the file to describe the problem settings, " + f"and response with the value {NO_FILE_IDENTIFIED} if there's no such file." 
) + return "\n\n".join([ + self.basic_intro_prompt, + "\n".join(file_content_prompts), + self.get_field_parsing_prompt(), + ]) + class DataFileNamePromptGenerator(PromptGenerator): + """Generator for prompts to identify data files.""" + fields = ["train_data", "test_data", "sample_submission_data"] - def __init__(self, data_description: str, filenames: list): + def __init__(self, data_description: str, filenames: List[str]) -> None: + """ + Initialize the data file name prompt generator. + + Args: + data_description: Description of the data + filenames: List of available filenames + """ super().__init__(data_description) self.filenames = filenames def generate_prompt(self) -> str: - file_content_prompts = "# Available Data Files And Columns in The File\n\n" + """Generate a prompt for identifying data files.""" + file_content_prompts = ["# Available Data Files And Columns in The File\n"] + for filename in self.filenames: try: content = load_pd_quietly(filename) - truncated_columns = content.columns[:10].tolist() - if len(content.columns) > 10: - truncated_columns.append("...") - # truncated_columns_str = ", ".join(truncated_columns) - file_content_prompts += f"File:\n\n{filename}" # \n\nTruncated Columns:\n{truncated_columns_str}\n\n" + file_content_prompts.append(f"File:\n\n{filename}") except Exception as e: print( - f"Failed to load data as a pandas Dataframe in {filename} with following error (please ignore this if it is not supposed to be a data file): {e}" + f"Failed to load data as a pandas DataFrame in {filename} " + f"with following error (please ignore this if it is not supposed " + f"to be a data file): {e}" ) continue - file_content_prompts += f"Based on the data description, what are the training, test, and output data? The output file may contain keywords such as benchmark, submission, or output. Please return the full path of the data files as provided, and response with the value {NO_FILE_IDENTIFIED} if there's no such File." - - return "\n\n".join( - [ - self.basic_intro_prompt, - file_content_prompts, - self.get_field_parsing_prompt(), - ] + file_content_prompts.append( + f"Based on the data description, what are the training, test, and output data? " + f"The output file may contain keywords such as benchmark, submission, or output. " + f"Please return the full path of the data files as provided, and response with " + f"the value {NO_FILE_IDENTIFIED} if there's no such File." ) + return "\n\n".join([ + self.basic_intro_prompt, + "\n".join(file_content_prompts), + self.get_field_parsing_prompt(), + ]) + class LabelColumnPromptGenerator(PromptGenerator): + """Generator for prompts to identify label columns.""" + fields = ["label_column"] - def __init__(self, data_description: str, column_names: list): + def __init__(self, data_description: str, column_names: List[str]) -> None: + """ + Initialize the label column prompt generator. 
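Each generator pairs its fields with langchain's StructuredOutputParser, built in _create_parser above, so the model's JSON reply can be converted back into a dict. A minimal sketch of that round trip; the field name and parsed value are illustrative:

from langchain.output_parsers import ResponseSchema, StructuredOutputParser

fields = ["label_column"]
parser = StructuredOutputParser.from_response_schemas(
    [ResponseSchema(name=field, description=f"The {field} for the task") for field in fields]
)

# The parser accepts the fenced JSON block that models typically return.
reply = '```json\n{"label_column": "target"}\n```'
print(parser.parse(reply))  # {'label_column': 'target'}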
+ + Args: + data_description: Description of the data + column_names: List of column names + """ super().__init__(data_description) self.column_names = get_outer_columns(column_names) def generate_prompt(self) -> str: - return "\n\n".join( - [ - self.basic_intro_prompt, - self.data_description_prompt, - f"Based on the data description, which one of these columns is likely to be the label column:\n{', '.join(self.column_names)}", - self.get_field_parsing_prompt(), - ] - ) + """Generate a prompt for identifying label columns.""" + return "\n\n".join([ + self.basic_intro_prompt, + self.data_description_prompt, + f"Based on the data description, which one of these columns is likely to be " + f"the label column:\n{', '.join(self.column_names)}", + self.get_field_parsing_prompt(), + ]) class ProblemTypePromptGenerator(PromptGenerator): + """Generator for prompts to identify problem types.""" + fields = ["problem_type"] def generate_prompt(self) -> str: - return "\n\n".join( - [ - self.basic_intro_prompt, - self.data_description_prompt, - f"Based on the information provided, identify the correct problem_type to be used from among these KEYS: {', '.join(PROBLEM_TYPES)}", - self.get_field_parsing_prompt(), - ] - ) - - -class IDColumnPromptGenerator(PromptGenerator): - fields = ["id_column"] - - def __init__(self, data_description: str, column_names: list, label_column: str): + """Generate a prompt for identifying problem types.""" + return "\n\n".join([ + self.basic_intro_prompt, + self.data_description_prompt, + f"Based on the information provided, identify the correct problem_type to be " + f"used from among these KEYS: {', '.join(PROBLEM_TYPES)}", + self.get_field_parsing_prompt(), + ]) + + +class BaseIDColumnPromptGenerator(PromptGenerator): + """Base class for ID column prompt generators.""" + + def __init__(self, data_description: str, column_names: List[str], label_column: str) -> None: + """ + Initialize the base ID column prompt generator. 
+ + Args: + data_description: Description of the data + column_names: List of column names + label_column: Name of the label column + """ super().__init__(data_description) self.column_names = get_outer_columns(column_names) self.label_column = label_column def generate_prompt(self) -> str: - return "\n\n".join( - [ - self.basic_intro_prompt, - self.data_description_prompt, - f"Based on the data description, which one of these columns is likely to be the Id column:\n{', '.join(self.column_names)}", - f"If no reasonable Id column is present, for example if all the columns appear to be similarly named feature columns, " - f"response with the value {NO_ID_COLUMN_IDENTIFIED}", - f"ID columns can't be {self.label_column}", - self.get_field_parsing_prompt(), - ] - ) + """Generate a prompt for identifying ID columns.""" + return "\n\n".join([ + self.basic_intro_prompt, + self.data_description_prompt, + f"Based on the data description, which one of these columns is likely to be " + f"the Id column:\n{', '.join(self.column_names)}", + f"If no reasonable Id column is present, for example if all the columns appear " + f"to be similarly named feature columns, response with the value " + f"{NO_ID_COLUMN_IDENTIFIED}", + f"ID columns can't be {self.label_column}", + self.get_field_parsing_prompt(), + ]) + + +class IDColumnPromptGenerator(BaseIDColumnPromptGenerator): + """Generator for prompts to identify general ID columns.""" + fields = ["id_column"] -class TestIDColumnPromptGenerator(IDColumnPromptGenerator): +class TestIDColumnPromptGenerator(BaseIDColumnPromptGenerator): + """Generator for prompts to identify test data ID columns.""" fields = ["test_id_column"] -class TrainIDColumnPromptGenerator(IDColumnPromptGenerator): +class TrainIDColumnPromptGenerator(BaseIDColumnPromptGenerator): + """Generator for prompts to identify training data ID columns.""" fields = ["train_id_column"] -class OutputIDColumnPromptGenerator(IDColumnPromptGenerator): +class OutputIDColumnPromptGenerator(BaseIDColumnPromptGenerator): + """Generator for prompts to identify output data ID columns.""" fields = ["output_id_column"] class EvalMetricPromptGenerator(PromptGenerator): + """Generator for prompts to identify evaluation metrics.""" + fields = ["eval_metric"] - def __init__(self, data_description: str, metrics: str): + def __init__(self, data_description: str, metrics: str) -> None: + """ + Initialize the evaluation metric prompt generator. + + Args: + data_description: Description of the data + metrics: Available metrics + """ super().__init__(data_description) self.metrics = metrics def generate_prompt(self) -> str: - return "\n\n".join( - [ - self.basic_intro_prompt, - self.data_description_prompt, - f""" -Based on the information provided, identify the correct evaluation metric to be used from among these KEYS: -{', '.join(self.metrics)} + """Generate a prompt for identifying evaluation metrics.""" + metric_descriptions = [METRICS_DESCRIPTION[metric] for metric in self.metrics] + + return "\n\n".join([ + self.basic_intro_prompt, + self.data_description_prompt, + f""" +Based on the information provided, identify the correct evaluation metric to be used +from among these KEYS: {', '.join(self.metrics)} + The descriptions of these metrics are: -{', '.join([METRICS_DESCRIPTION[metric] for metric in self.metrics])} +{', '.join(metric_descriptions)} respectively. -If the exact metric is not in the list provided, then choose the metric that you think best approximates the one in the task description. 
-Only respond with the exact names of the metrics mentioned in KEYS. Do not respond with the metric descriptions. + +If the exact metric is not in the list provided, then choose the metric that you think +best approximates the one in the task description. + +Only respond with the exact names of the metrics mentioned in KEYS. Do not respond with +the metric descriptions. """, - self.get_field_parsing_prompt(), - ] - ) + self.get_field_parsing_prompt(), + ]) diff --git a/src/autogluon_assistant/task_inference/task_inference.py b/src/autogluon_assistant/task_inference/task_inference.py index a18690a..ca3ecd1 100644 --- a/src/autogluon_assistant/task_inference/task_inference.py +++ b/src/autogluon_assistant/task_inference/task_inference.py @@ -1,9 +1,11 @@ +from typing import Any, Dict, List, Optional, Union import difflib import logging -from typing import Any, Dict, List # Added Union for type hinting +from pathlib import Path from autogluon.core.utils.utils import infer_problem_type -from langchain_core.exceptions import OutputParserException # Updated import +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import BaseMessage from autogluon_assistant.prompting import ( DataFileNamePromptGenerator, @@ -30,56 +32,105 @@ class TaskInference: - """Parses data and metadata of a task with the aid of an instruction-tuned LLM.""" + """Base class for parsing data and metadata of a task with the aid of an instruction-tuned LLM.""" - def __init__(self, llm, *args, **kwargs): + def __init__(self, llm: Any, *args: Any, **kwargs: Any) -> None: + """Initialize TaskInference. + + Args: + llm: Language model instance + *args: Additional positional arguments + **kwargs: Additional keyword arguments + """ super().__init__(*args, **kwargs) self.llm = llm - self.fallback_value = None - self.ignored_value: List[str] = [] # Added type hint + self.fallback_value: Optional[str] = None + self.ignored_value: List[str] = [] + self.prompt_generator: Optional[Any] = None + self.valid_values: Optional[List[str]] = None + + def initialize_task(self, task: TabularPredictionTask) -> None: + """Initialize task-specific attributes. - def initialize_task(self, task): + Args: + task: TabularPredictionTask instance + """ self.prompt_generator = None self.valid_values = None def log_value(self, key: str, value: Any, max_width: int = 1600) -> None: - """Logs a key-value pair with formatted output.""" + """Log a key-value pair with formatted output. + + Args: + key: Key to log + value: Value to log + max_width: Maximum width of the log message + """ if not value: logger.info(f"WARNING: Failed to identify the {key} of the task, it is set to None.") return - prefix = key # f"Identified the {key} of the task: " value_str = str(value).replace("\n", "\\n") + if len(key) + len(value_str) > max_width: + value_str = value_str[:max_width - len(key) - 3] + "..." - if len(prefix) + len(value_str) > max_width: - value_str = value_str[: max_width - len(prefix) - 3] + "..." + logger.info(f"\033[1m{key}\033[0m: {value_str}") - bold_start = "\033[1m" - bold_end = "\033[0m" + def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: + """Transform the task using LLM inference. 
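TaskInference.transform, continued just below, writes every parsed field onto the task with setattr after mapping ignored sentinel values to None. A toy sketch of that assignment pattern with a stand-in task object and sentinel:

class ToyTask:
    label_column = None
    eval_metric = None


parsed_output = {"label_column": "target", "eval_metric": "__NO_VALUE__"}
ignored_values = ["__NO_VALUE__"]  # stand-in for sentinels such as NO_FILE_IDENTIFIED

task = ToyTask()
for key, value in parsed_output.items():
    setattr(task, key, None if value in ignored_values else value)

print(task.label_column, task.eval_metric)  # target None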
- logger.info(f"{bold_start}{prefix}{bold_end}: {value_str}") + Args: + task: TabularPredictionTask instance - def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: + Returns: + Modified TabularPredictionTask instance + """ self.initialize_task(task) parser_output = self._chat_and_parse_prompt_output() - for k, v in parser_output.items(): - if v in self.ignored_value: - v = None - self.log_value(k, v) - setattr(task, k, self.post_process(task=task, value=v)) + for key, value in parser_output.items(): + if value in self.ignored_value: + value = None + self.log_value(key, value) + setattr(task, key, self.post_process(task=task, value=value)) return task - def post_process(self, task, value): + def post_process(self, task: TabularPredictionTask, value: Any) -> Any: + """Post-process the parsed value. + + Args: + task: TabularPredictionTask instance + value: Value to post-process + + Returns: + Processed value + """ return value - def parse_output(self, output): - assert self.prompt_generator is not None, "prompt_generator is not initialized" + def parse_output(self, output: BaseMessage) -> Dict[str, str]: + """Parse LLM output using the prompt generator's parser. + + Args: + output: LLM output message + + Returns: + Parsed output dictionary + """ + if not self.prompt_generator: + raise ValueError("prompt_generator is not initialized") return self.prompt_generator.parser.parse(output.content) def _chat_and_parse_prompt_output(self) -> Dict[str, str]: - """Chat with the LLM and parse the output""" + """Chat with the LLM and parse the output. + + Returns: + Dictionary containing parsed output + + Raises: + OutputParserException: If parsing fails + ValueError: If parsed value is not in valid_values + """ try: - chat_prompt = self.prompt_generator.generate_chat_prompt() + chat_prompt = self.prompt_generator.generate_chat_prompt() # type: ignore logger.debug(f"LLM chat_prompt:\n{chat_prompt.format_messages()}") output = self.llm.invoke(chat_prompt.format_messages()) logger.debug(f"LLM output:\n{output}") @@ -87,148 +138,192 @@ def _chat_and_parse_prompt_output(self) -> Dict[str, str]: except OutputParserException as e: logger.error(f"Failed to parse output: {e}") logger.error(self.llm.describe()) - raise e - - if self.valid_values is not None: - for key, parsed_value in parsed_output.items(): - if parsed_value not in self.valid_values: - # Currently only support single parsed value - if isinstance(parsed_value, str): - close_matches = difflib.get_close_matches(parsed_value, self.valid_values) - elif isinstance(parsed_value, list) and len(parsed_value) == 1: - parsed_value = parsed_value[0] - close_matches = difflib.get_close_matches(parsed_value, self.valid_values) - else: - logger.warning( - f"Unrecognized parsed value: {parsed_value} for key {key} parsed by the LLM. " - f"It has type: {type(parsed_value)}." - ) - close_matches = [] - - if len(close_matches) == 0: - if self.fallback_value: - logger.warning( - f"Unrecognized value: {parsed_value} for key {key} parsed by the LLM. " - f"Will use default value: {self.fallback_value}." 
- ) - parsed_output[key] = self.fallback_value - else: - raise ValueError(f"Unrecognized value: {parsed_value} for key {key} parsed by the LLM.") - else: - parsed_output[key] = close_matches[0] + raise + if self.valid_values: + parsed_output = self._validate_and_correct_output(parsed_output) return parsed_output + def _validate_and_correct_output(self, parsed_output: Dict[str, str]) -> Dict[str, str]: + """Validate and correct parsed output against valid values. + + Args: + parsed_output: Dictionary of parsed values + + Returns: + Validated and corrected output dictionary + + Raises: + ValueError: If no valid match is found and no fallback value is set + """ + if not self.valid_values: + return parsed_output + + for key, parsed_value in parsed_output.items(): + if parsed_value in self.valid_values: + continue + + close_matches = self._get_close_matches(parsed_value) + + if not close_matches: + if self.fallback_value: + logger.warning( + f"Unrecognized value: {parsed_value} for key {key} parsed by the LLM. " + f"Using default value: {self.fallback_value}." + ) + parsed_output[key] = self.fallback_value + else: + raise ValueError(f"Unrecognized value: {parsed_value} for key {key} parsed by the LLM.") + else: + parsed_output[key] = close_matches[0] + + return parsed_output + + def _get_close_matches(self, value: Union[str, List[str]]) -> List[str]: + """Get close matches for a value from valid_values. + + Args: + value: Value to match against valid_values + + Returns: + List of close matches + """ + if isinstance(value, str): + return difflib.get_close_matches(value, self.valid_values or []) + elif isinstance(value, list) and len(value) == 1: + return difflib.get_close_matches(value[0], self.valid_values or []) + else: + logger.warning(f"Unrecognized parsed value: {value} with type: {type(value)}.") + return [] + class DescriptionFileNameInference(TaskInference): - """Uses an LLM to locate the filenames of description files. - TODO: merge the logics with DataFileNameInference and add support for multiple files per field. - """ + """Infers description filenames using LLM.""" - def initialize_task(self, task): + def initialize_task(self, task: TabularPredictionTask) -> None: filenames = [str(path) for path in task.filepaths] self.valid_values = filenames + [NO_FILE_IDENTIFIED] self.fallback_value = NO_FILE_IDENTIFIED self.prompt_generator = DescriptionFileNamePromptGenerator(filenames=filenames) - def _read_descriptions(self, parser_output: dict) -> str: + def _read_descriptions(self, parser_output: Dict[str, Union[str, List[str]]]) -> str: + """Read and combine descriptions from identified files. 
+ + Args: + parser_output: Dictionary containing file paths + + Returns: + Combined description string + """ description_parts = [] for key, file_paths in parser_output.items(): if isinstance(file_paths, str): - file_paths = [file_paths] # Convert single string to list + file_paths = [file_paths] for file_path in file_paths: if file_path == NO_FILE_IDENTIFIED: continue - else: - try: - with open(file_path, "r") as file: - content = file.read() - description_parts.append(f"{key}: {content}") - except FileNotFoundError: - continue - except IOError: - continue + try: + with open(file_path, "r") as file: + content = file.read() + description_parts.append(f"{key}: {content}") + except (FileNotFoundError, IOError): + continue return "\n\n".join(description_parts) def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: self.initialize_task(task) parser_output = self._chat_and_parse_prompt_output() - descriptions_read = self._read_descriptions(parser_output) - if descriptions_read: - task.metadata["description"] = descriptions_read - self.log_value("description", descriptions_read) + descriptions = self._read_descriptions(parser_output) + if descriptions: + task.metadata["description"] = descriptions + self.log_value("description", descriptions) return task class DataFileNameInference(TaskInference): - """Uses an LLM to locate the filenames of the train, test, and output data, - and assigns them to the respective properties of the task. - """ + """Infers data filenames for train, test, and output data.""" - def initialize_task(self, task): + def initialize_task(self, task: TabularPredictionTask) -> None: filenames = [str(path) for path in task.filepaths] self.valid_values = filenames + [NO_FILE_IDENTIFIED] self.fallback_value = NO_FILE_IDENTIFIED self.ignored_value = [NO_FILE_IDENTIFIED] self.prompt_generator = DataFileNamePromptGenerator( - data_description=task.metadata["description"], filenames=filenames + data_description=task.metadata["description"], + filenames=filenames ) class LabelColumnInference(TaskInference): - def initialize_task(self, task): - column_names = list(task.train_data.columns) - self.valid_values = column_names + """Infers label column from data.""" + + def initialize_task(self, task: TabularPredictionTask) -> None: + self.valid_values = list(task.train_data.columns) self.prompt_generator = LabelColumnPromptGenerator( - data_description=task.metadata["description"], column_names=column_names + data_description=task.metadata["description"], + column_names=self.valid_values ) class ProblemTypeInference(TaskInference): - def initialize_task(self, task): + """Infers problem type from data.""" + + def initialize_task(self, task: TabularPredictionTask) -> None: self.valid_values = PROBLEM_TYPES - self.prompt_generator = ProblemTypePromptGenerator(data_description=task.metadata["description"]) + self.prompt_generator = ProblemTypePromptGenerator( + data_description=task.metadata["description"] + ) - def post_process(self, task, value): - # LLM may get confused between BINARY and MULTICLASS as it cannot see the whole label column + def post_process(self, task: TabularPredictionTask, value: str) -> str: if value in CLASSIFICATION_PROBLEM_TYPES: - problem_type_infered_by_autogluon = infer_problem_type(task.train_data[task.label_column], silent=True) - if problem_type_infered_by_autogluon in CLASSIFICATION_PROBLEM_TYPES: - value = problem_type_infered_by_autogluon + inferred_type = infer_problem_type(task.train_data[task.label_column], silent=True) + if 
inferred_type in CLASSIFICATION_PROBLEM_TYPES: + return inferred_type return value class BaseIDColumnInference(TaskInference): - def __init__(self, *args, **kwargs): + """Base class for ID column inference.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.valid_values = [] self.fallback_value = NO_ID_COLUMN_IDENTIFIED self.prompt_generator = None - def initialize_task(self, task, description=None): - if self.get_data(task) is None: + def initialize_task(self, task: TabularPredictionTask, description: Optional[str] = None) -> None: + data = self.get_data(task) + if data is None: return - column_names = list(self.get_data(task).columns) - # Assume ID column can only appear in first 3 columns - if len(column_names) >= 3: - column_names = column_names[:3] + column_names = list(data.columns)[:3] # Consider only first 3 columns self.valid_values = column_names + [NO_ID_COLUMN_IDENTIFIED] + if not description: description = task.metadata["description"] + self.prompt_generator = self.get_prompt_generator()( - data_description=description, column_names=column_names, label_column=task.metadata["label_column"] + data_description=description, + column_names=column_names, + label_column=task.metadata["label_column"] ) - def get_data(self, task): - pass + def get_data(self, task: TabularPredictionTask) -> Any: + """Get relevant data from task.""" + raise NotImplementedError("Subclasses must implement get_data") + + def get_prompt_generator(self) -> Any: + """Get appropriate prompt generator.""" + raise NotImplementedError("Subclasses must implement get_prompt_generator") - def get_prompt_generator(self): - pass + def get_id_column_name(self) -> str: + """Get name of ID column.""" + raise NotImplementedError("Subclasses must implement get_id_column_name") - def get_id_column_name(self): - pass + def process_id_column(self, task: TabularPredictionTask, id_column: str) -> Optional[str]: + """Process identified ID column.""" + raise NotImplementedError("Subclasses must implement process_id_column") def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: if self.get_data(task) is None: @@ -241,10 +336,11 @@ def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: if parser_output[id_column_name] == NO_ID_COLUMN_IDENTIFIED: logger.warning( - "Failed to infer ID column with data descriptions. " "Retry the inference without data descriptions." + "Failed to infer ID column with data descriptions. Retrying without descriptions." ) self.initialize_task( - task, description="Missing data description. Please infer the ID column based on given column names." + task, + description="Missing data description. Please infer the ID column based on given column names." 
) parser_output = self._chat_and_parse_prompt_output() @@ -254,77 +350,126 @@ def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: setattr(task, id_column_name, id_column) return task - def process_id_column(self, task, id_column): - pass - class TestIDColumnInference(BaseIDColumnInference): - def get_data(self, task): + """Infers test data ID column.""" + + def get_data(self, task: TabularPredictionTask) -> Any: return task.test_data - def get_prompt_generator(self): + def get_prompt_generator(self) -> Any: return TestIDColumnPromptGenerator - def get_id_column_name(self): + def get_id_column_name(self) -> str: return "test_id_column" - def process_id_column(self, task, id_column): + def process_id_column(self, task: TabularPredictionTask, id_column: str) -> str: if task.output_id_column != NO_ID_COLUMN_IDENTIFIED: - # if output data has id column but test data does not if id_column == NO_ID_COLUMN_IDENTIFIED: - if task.output_id_column not in task.test_data: - id_column = task.output_id_column - else: - id_column = "id_column" + id_column = ( + task.output_id_column + if task.output_id_column not in task.test_data + else "id_column" + ) new_test_data = task.test_data.copy() new_test_data[id_column] = task.sample_submission_data[task.output_id_column] task.test_data = new_test_data - return id_column class TrainIDColumnInference(BaseIDColumnInference): - def get_data(self, task): + """Infers training data ID column.""" + + def get_data(self, task: TabularPredictionTask) -> Any: return task.train_data - def get_prompt_generator(self): + def get_prompt_generator(self) -> Any: return TrainIDColumnPromptGenerator - def get_id_column_name(self): + def get_id_column_name(self) -> str: return "train_id_column" - def process_id_column(self, task, id_column): + def process_id_column(self, task: TabularPredictionTask, id_column: str) -> str: if id_column != NO_ID_COLUMN_IDENTIFIED: - new_train_data = task.train_data.copy() - new_train_data = new_train_data.drop(columns=[id_column]) + new_train_data = task.train_data.drop(columns=[id_column]) task.train_data = new_train_data logger.info(f"Dropping ID column {id_column} from training data.") task.metadata["dropped_train_id_column"] = True - return id_column class OutputIDColumnInference(BaseIDColumnInference): - def get_data(self, task): + """Infers output data ID column.""" + + def get_data(self, task: TabularPredictionTask) -> Any: + """Get sample submission data from task. + + Args: + task: TabularPredictionTask instance + + Returns: + Sample submission data + """ return task.sample_submission_data - def get_prompt_generator(self): + def get_prompt_generator(self) -> Any: + """Get OutputIDColumnPromptGenerator. + + Returns: + OutputIDColumnPromptGenerator class + """ return OutputIDColumnPromptGenerator - def get_id_column_name(self): + def get_id_column_name(self) -> str: + """Get output ID column name. + + Returns: + Name of output ID column + """ return "output_id_column" - def process_id_column(self, task, id_column): + def process_id_column(self, task: TabularPredictionTask, id_column: str) -> str: + """Process output ID column (no processing needed). 
+ + Args: + task: TabularPredictionTask instance + id_column: Identified ID column name + + Returns: + Original ID column name + """ return id_column class EvalMetricInference(TaskInference): - def initialize_task(self, task): + """Infers evaluation metric based on problem type.""" + + def initialize_task(self, task: TabularPredictionTask) -> None: + """Initialize evaluation metric inference. + + Determines available metrics based on problem type and sets up the prompt generator. + + Args: + task: TabularPredictionTask instance + """ problem_type = task.problem_type - self.metrics = METRICS_DESCRIPTION.keys() if problem_type is None else METRICS_BY_PROBLEM_TYPE[problem_type] + + # Determine available metrics based on problem type + self.metrics = ( + list(METRICS_DESCRIPTION.keys()) + if problem_type is None + else METRICS_BY_PROBLEM_TYPE[problem_type] + ) + + # Set valid values for metric validation self.valid_values = self.metrics + + # Set fallback value if problem type is available if problem_type: self.fallback_value = METRICS_BY_PROBLEM_TYPE[problem_type][0] + + # Initialize prompt generator with available metrics self.prompt_generator = EvalMetricPromptGenerator( - data_description=task.metadata["description"], metrics=self.metrics + data_description=task.metadata["description"], + metrics=self.metrics )
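
Note for reviewers: a minimal standalone sketch of the fuzzy-match fallback behavior implemented in TaskInference._validate_and_correct_output / _get_close_matches — an exact hit is kept, otherwise difflib.get_close_matches picks the nearest valid value, otherwise the fallback applies. correct_value below is a hypothetical helper used only for illustration (it is not part of the patch) and assumes string-valued valid_values.

    import difflib
    from typing import List, Optional


    def correct_value(parsed: str, valid_values: List[str], fallback: Optional[str] = None) -> str:
        """Return `parsed` if valid, else the closest valid value, else `fallback`."""
        if parsed in valid_values:
            return parsed
        close = difflib.get_close_matches(parsed, valid_values)
        if close:
            return close[0]
        if fallback is not None:
            return fallback
        raise ValueError(f"Unrecognized value: {parsed}")


    # A near-miss such as 'rmse ' (trailing space) resolves to 'rmse';
    # an unrecognized value falls back to the supplied default.
    print(correct_value("rmse ", ["rmse", "roc_auc", "accuracy"]))         # rmse
    print(correct_value("unknown", ["rmse", "roc_auc"], fallback="rmse"))  # rmse

difflib's default similarity cutoff (0.6) decides when a near-miss from the LLM is silently corrected versus routed to the fallback value, which is the same trade-off the patched _validate_and_correct_output relies on.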