Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Presets #85

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions config/config.yaml → configs/best_quality.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ save_artifacts:
enabled: False
append_timestamp: True
path: "./aga-artifacts"
# path: "s3://autogluon-assistant-agts/outputs/<run_id>/aga-artifacts/"
feature_transformers:
- _target_: autogluon_assistant.transformer.CAAFETransformer
eval_model: lightgbm
Expand All @@ -25,19 +24,18 @@ autogluon:
predictor_fit_kwargs:
verbosity: 2
presets: best_quality
time_limit: 60
time_limit: 14400
dynamic_stacking: True
num_bag_folds: 5
num_stack_levels: 2
llm:
# Note: bedrock is only supported in limited AWS regions
provider: bedrock
api_key_location: BEDROCK_API_KEY
model: "anthropic.claude-3-5-sonnet-20241022-v2:0" #"anthropic.claude-3-5-sonnet-20240620-v1:0"
model: "anthropic.claude-3-5-sonnet-20241022-v2:0"
# provider: openai
# api_key_location: OPENAI_API_KEY
# model: gpt-4o-2024-08-06
# model: gpt-3.5-turbo
max_tokens: 512
proxy_url: null
temperature: 0
Expand Down
29 changes: 29 additions & 0 deletions configs/high_quality.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# AutoGluon-Assistant preset configuration: "high_quality".
# Resolved by _get_default_config_path() as configs/<preset>.yaml when the
# user passes --presets high_quality (or it is selected as a fallback).
infer_eval_metric: True  # infer the evaluation metric automatically from the task
detect_and_drop_id_column: False
task_preprocessors_timeout: 3600  # preprocessing time budget — presumably seconds; confirm
save_artifacts:
  enabled: False  # artifact saving disabled by default
  append_timestamp: True
  path: "./aga-artifacts"
feature_transformers:  # null — no feature transformers configured for this preset
autogluon:
  predictor_init_kwargs: {}  # presumably forwarded to the predictor constructor — confirm
  predictor_fit_kwargs:  # presumably forwarded to the predictor's fit() call — confirm
    verbosity: 2
    presets: high_quality
    time_limit: 3600  # training time limit in seconds (AutoGluon fit budget)
    dynamic_stacking: True
    num_bag_folds: 5
    num_stack_levels: 2
llm:
  # Note: bedrock is only supported in limited AWS regions
  provider: bedrock
  api_key_location: BEDROCK_API_KEY  # presumably the env var holding the key — confirm
  model: "anthropic.claude-3-5-sonnet-20241022-v2:0"
  # provider: openai
  # api_key_location: OPENAI_API_KEY
  # model: gpt-4o-2024-08-06
  max_tokens: 512
  proxy_url: null  # no proxy by default
  temperature: 0  # deterministic LLM sampling
verbose: True
29 changes: 29 additions & 0 deletions configs/medium_quality.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# AutoGluon-Assistant preset configuration: "medium_quality".
# Resolved by _get_default_config_path() as configs/<preset>.yaml when the
# user passes --presets medium_quality.
infer_eval_metric: True  # infer the evaluation metric automatically from the task
detect_and_drop_id_column: False
task_preprocessors_timeout: 3600  # preprocessing time budget — presumably seconds; confirm
save_artifacts:
  enabled: False  # artifact saving disabled by default
  append_timestamp: True
  path: "./aga-artifacts"
feature_transformers:  # null — no feature transformers configured for this preset
autogluon:
  predictor_init_kwargs: {}  # presumably forwarded to the predictor constructor — confirm
  predictor_fit_kwargs:  # presumably forwarded to the predictor's fit() call — confirm
    verbosity: 2
    presets: medium_quality
    time_limit: 600  # shortest fit budget of the three presets (seconds)
    dynamic_stacking: True
    num_bag_folds: 5
    num_stack_levels: 2
llm:
  # Note: bedrock is only supported in limited AWS regions
  provider: bedrock
  api_key_location: BEDROCK_API_KEY  # presumably the env var holding the key — confirm
  model: "anthropic.claude-3-5-sonnet-20241022-v2:0"
  # provider: openai
  # api_key_location: OPENAI_API_KEY
  # model: gpt-4o-2024-08-06
  max_tokens: 512
  proxy_url: null  # no proxy by default
  temperature: 0  # deterministic LLM sampling
verbose: True
14 changes: 12 additions & 2 deletions src/autogluon_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Annotated

from .assistant import TabularPredictionAssistant
from .constants import NO_ID_COLUMN_IDENTIFIED
from .constants import DEFAULT_QUALITY, NO_ID_COLUMN_IDENTIFIED, PRESETS
from .task import TabularPredictionTask
from .utils import load_config

Expand Down Expand Up @@ -74,6 +74,10 @@ def make_prediction_outputs(task: TabularPredictionTask, predictions: pd.DataFra

def run_assistant(
task_path: Annotated[str, typer.Argument(help="Directory where task files are included")],
presets: Annotated[
Optional[str],
typer.Option("--presets", "-p", help="Presets"),
] = None,
config_path: Annotated[
Optional[str],
typer.Option("--config-path", "-c", help="Path to the configuration file (config.yaml)"),
Expand All @@ -90,9 +94,15 @@ def run_assistant(
) -> str:
logging.info("Starting AutoGluon-Assistant")

if presets is None or presets not in PRESETS:
logging.info(f"Presets is not provided or invalid: {presets}")
presets = DEFAULT_QUALITY
logging.info(f"Using default presets: {presets}")
logging.info(f"Presets: {presets}")

# Load config with all overrides
try:
config = load_config(config_path, config_overrides)
config = load_config(presets, config_path, config_overrides)
logging.info("Successfully loaded config")
except Exception as e:
logging.error(f"Failed to load config: {e}")
Expand Down
8 changes: 8 additions & 0 deletions src/autogluon_assistant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
PROBLEM_TYPES = [BINARY, MULTICLASS, REGRESSION]
CLASSIFICATION_PROBLEM_TYPES = [BINARY, MULTICLASS]

# Presets/Configs
# CONFIGS is the name of the directory that holds one YAML file per preset
# (configs/<preset>.yaml); it is joined onto the package root in
# utils/configs.py when resolving the default config path.
CONFIGS = "configs"
MEDIUM_QUALITY = "medium_quality"
HIGH_QUALITY = "high_quality"
BEST_QUALITY = "best_quality"
# Fallback preset applied by run_assistant() when the user supplies no
# preset or an unrecognized one.
DEFAULT_QUALITY = BEST_QUALITY
# The set of preset names accepted on the command line (--presets).
PRESETS = [MEDIUM_QUALITY, HIGH_QUALITY, BEST_QUALITY]

# Metrics
ROC_AUC = "roc_auc"
LOG_LOSS = "log_loss"
Expand Down
14 changes: 10 additions & 4 deletions src/autogluon_assistant/utils/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@

from omegaconf import OmegaConf

from ..constants import CONFIGS

def _get_default_config_path() -> Path:

def _get_default_config_path(
presets: str,
) -> Path:
"""
Get default config folder under package root
Returns Path to the config.yaml file
"""
current_file = Path(__file__).parent.parent.parent.parent.absolute()
config_path = current_file / "config" / "config.yaml"
config_path = current_file / CONFIGS / f"{presets}.yaml"

if not config_path.exists():
raise ValueError(f"Config file not found at expected location: {config_path}")
Expand Down Expand Up @@ -77,7 +81,9 @@ def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, A
return OmegaConf.merge(config, override_conf)


def load_config(config_path: Optional[str] = None, overrides: Optional[List[str]] = None) -> Dict[str, Any]:
def load_config(
presets: str, config_path: Optional[str] = None, overrides: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Load configuration from yaml file, merging with default config and applying overrides

Expand All @@ -92,7 +98,7 @@ def load_config(config_path: Optional[str] = None, overrides: Optional[List[str]
ValueError: If config file not found or invalid
"""
# Load default config
default_config_path = _get_default_config_path()
default_config_path = _get_default_config_path(presets)
logging.info(f"Loading default config from: {default_config_path}")
config = OmegaConf.load(default_config_path)

Expand Down
Loading