diff --git a/docs/advanced_usage/13_adaptive_capping.md b/docs/advanced_usage/13_adaptive_capping.md new file mode 100644 index 000000000..74059763a --- /dev/null +++ b/docs/advanced_usage/13_adaptive_capping.md @@ -0,0 +1,107 @@ +# Adaptive Capping + +Adaptive capping is a feature that can be used to speed up the evaluation of candidate configurations when the objective +is to minimize runtime of an algorithm across a set of instances. The basic idea is to terminate unpromising candidates +early and to adapt the timeout for solving a single instance dynamically based on the incumbent's runtime and the +runtime already used by the challenger configuration. + +## Theoretical Background + +When comparing a challenger configuration with the current incumbent for a (sub-)set of instances, we already know how +much cost (in terms of runtime) was incurred by the incumbent to solve the set of instances. As soon as the challenger +configuration exceeds the cost of the incumbent, it is evident that the challenger will not become the new incumbent +since the costs accumulate over time and are strictly positive, i.e., solving an instance cannot have negative runtime. + +Example: +*Let the incumbent be evaluated for two instances with observed runtimes 3s and 4s. When a challenger configuration is +evaluated and compared against the incumbent, it is first evaluated on the first instance. For example, we observe a +runtime of 2s. As the challenger appears to be a promising configuration, its evaluation is intensified and the budget +is doubled, i.e., the budget is increased to 2. For solving the second instance, adaptive capping will allow a timeout +of 5s since the sum of runtimes for the incumbent is 7s and the challenger used up 2s for solving the first instance so +far so that 5s remain until the costs of the incumbent are exceeded. Even if the challenger configuration would need 10s +to solve the second instance, its execution would be aborted after 5s. 
In this example, by adaptive capping we thus save 5s of +evaluation costs for the challenger to notice that it will not replace the current incumbent.* + +In combination with random online aggressive racing, we can further speedup the evaluation of challenger configurations +as we increase the horizon for adaptive capping step by step with every step of intensification. Note that +intensification will double the number of instances to which the challenger configuration (and eventually also the +incumbent configuration) are applied to. Furthermore, to increase the trust into the current incumbent, the incumbent is +regularly subject to intensification. + + +## Setting up Adaptive Capping + +To achieve this, the user must take active care in the termination of their target function. +The capped problem.train will receive a budget keyword argument, detailing the seconds allocated to the configuration. +Below is an example of a capped problem that will return the used budget if the computation exceeds the budget. + + +```python + + class TimeoutException(Exception): + pass + + + @contextmanager + def timeout(seconds): + def handler(signum, frame): + raise TimeoutException(f"Function call exceeded timeout of {seconds} seconds") + + # Set the signal handler for the alarm signal + signal.signal(signal.SIGALRM, handler) + signal.alarm(seconds) # Schedule an alarm after the given number of seconds + + try: + yield + finally: + # Cancel the alarm if the block finishes before timeout + signal.alarm(0) + + + class CappedProblem: + @property + def configspace(self) -> ConfigurationSpace: + ... + + def train(self, config: Configuration, instance:str, budget, seed: int = 0) -> float: + + try: + with timeout(int(math.ceil(budget))): + start_time = time.time() + ... 
# heavy computation + runtime = time.time() - start_time + return runtime + except TimeoutException as e: + print(f"Timeout for configuration {config} with runtime budget {budget}") + return budget # here the runtime is capped and we return the used budget. +``` + +In order to enable adaptive capping in smac, we need to create [problem instances](4_instances.md) to optimize over and specify a +global runtime cutoff in the intensifier. Then we optimize as usual. + + +```python +from smac.intensifier import Intensifier +from smac.scenario.scenario import Scenario + +scenario = Scenario( + capped_problem.configspace, + ... + instances=['1', '2', '3'], # add problem instances we want to solve + instance_features={'1': [1], '2': [2], '3': [3]} # in the absence of actual features add dummy features for identification +) + +intensifier = Intensifier( +scenario, +runtime_cutoff=10 # specify an absolute runtime cutoff (sum over instances) never to be exceeded +) + +smac = HyperparameterOptimizationFacade( + scenario, + capped_problem.train, + intensifier=intensifier, + ... +) + +incumbent = smac.optimize() +``` diff --git a/examples/1_basics/8_adaptive_capping.py b/examples/1_basics/8_adaptive_capping.py new file mode 100644 index 000000000..3bd346c3a --- /dev/null +++ b/examples/1_basics/8_adaptive_capping.py @@ -0,0 +1,103 @@ +"""Adaptive Capping +# Flags: doc-Runnable + +Adaptive capping is often used in optimization algorithms, particularly in +scenarios where the time taken to evaluate solutions can vary significantly. +For more details on adaptive capping, consult the [info page adaptive capping](../../../advanced_usage/13_adaptive_capping.html). 
+ +""" +import math +import time + +import signal +from contextlib import contextmanager + +from smac.runhistory import InstanceSeedBudgetKey, TrialInfo + +import warnings + +from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Float + + +class TimeoutException(Exception): + pass + +@contextmanager +def timeout(seconds): + def handler(signum, frame): + raise TimeoutException(f"Function call exceeded timeout of {seconds} seconds") + + # Set the signal handler for the alarm signal + signal.signal(signal.SIGALRM, handler) + signal.alarm(seconds) # Schedule an alarm after the given number of seconds + + try: + yield + finally: + # Cancel the alarm if the block finishes before timeout + signal.alarm(0) + + +class CappedProblem: + @property + def configspace(self) -> ConfigurationSpace: + cs = ConfigurationSpace(seed=0) + x0 = Float("x0", (0, 5), default=5, log=False) + x1 = Float("x1", (0, 7), default=7, log=False) + cs.add_hyperparameters([x0, x1]) + return cs + + def train(self, config: Configuration, instance:str, budget, seed: int = 0) -> float: + x0 = config["x0"] + x1 = config["x1"] + + try: + with timeout(int(math.ceil(budget))): + runtime = 0.5 * x1 + 0.5 * x0 * int(instance) + time.sleep(runtime) + return runtime + except TimeoutException as e: + print(f"Timeout for configuration {config} with runtime budget {budget}") + return budget # FIXME: what should be returned here? + + +if __name__ == '__main__': + from smac import HyperparameterOptimizationFacade, RunHistory + from smac import Scenario + + from smac.intensifier import Intensifier + + capped_problem = CappedProblem() + + scenario = Scenario( + capped_problem.configspace, + walltime_limit=3600, # After 200 seconds, we stop the hyperparameter optimization + n_trials=500, # Evaluate max 500 different trials + instances=['1', '2', '3'], + instance_features={'1': [1], '2': [2], '3': [3]} + ) + + # We want to run five random configurations before starting the optimization. 
+ initial_design = HyperparameterOptimizationFacade.get_initial_design(scenario, n_configs=5) + + intensifier = Intensifier(scenario, runtime_cutoff=10, adaptive_capping_slackfactor=1.2) + + # Create our SMAC object and pass the scenario and the train method + smac = HyperparameterOptimizationFacade( + scenario, + capped_problem.train, + initial_design=initial_design, + intensifier=intensifier, + overwrite=True, + ) + + # Let's optimize + incumbent = smac.optimize() + + # Get cost of default configuration + default_cost = smac.validate(capped_problem.configspace.get_default_configuration()) + print(f"Default cost ({intensifier.__class__.__name__}): {default_cost}") + + # Let's calculate the cost of the incumbent + incumbent_cost = smac.validate(incumbent) + print(f"Incumbent cost ({intensifier.__class__.__name__}): {incumbent_cost}") diff --git a/examples/1_basics/8_warmstart.py b/examples/1_basics/8_warmstart.py index 209b0b83e..14d75c6e3 100644 --- a/examples/1_basics/8_warmstart.py +++ b/examples/1_basics/8_warmstart.py @@ -3,7 +3,7 @@ With the ask and tell interface, we can support warmstarting SMAC. We can communicate rich information about the previous trials to SMAC using `TrialInfo` and `TrialValue` instances. -For more details on ask and tell consult the [info page ask-and-tell](../../../advanced_usage/5_ask_and_tell). +For more details on ask and tell consult the [info page ask-and-tell](../../../advanced_usage/5_ask_and_tell.html). 
""" from __future__ import annotations diff --git a/mkdocs.yaml b/mkdocs.yaml index 1cd2ff26c..013fef207 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -207,6 +207,7 @@ nav: - "advanced_usage/10_continue.md" - "advanced_usage/11_reproducibility.md" - "advanced_usage/12_optimizations.md" + - "advanced_usage/13_adaptive_capping.md" # Auto generated with docs/examples_runner.py - Examples: "examples/" # Auto generated with docs/api_generator.py diff --git a/smac/facade/abstract_facade.py b/smac/facade/abstract_facade.py index 9d3912061..8c4cc1429 100644 --- a/smac/facade/abstract_facade.py +++ b/smac/facade/abstract_facade.py @@ -119,6 +119,7 @@ def __init__( multi_objective_algorithm: AbstractMultiObjectiveAlgorithm | None = None, runhistory_encoder: AbstractRunHistoryEncoder | None = None, config_selector: ConfigSelector | None = None, + runtime_cutoff: int | None = None, logging_level: int | Path | Literal[False] | None = None, callbacks: list[Callback] = None, overwrite: bool = False, @@ -175,6 +176,7 @@ def __init__( self._runhistory = runhistory self._runhistory_encoder = runhistory_encoder self._config_selector = config_selector + self._runtime_cutoff = runtime_cutoff self._callbacks = callbacks self._overwrite = overwrite @@ -485,4 +487,7 @@ def _get_signature_arguments(self) -> list[str]: if self._intensifier.uses_instances: arguments += ["instance"] + if self._intensifier.uses_cutoffs: + arguments += ["cutoff"] + return arguments diff --git a/smac/facade/algorithm_configuration_facade.py b/smac/facade/algorithm_configuration_facade.py index bc1072300..d0173d142 100644 --- a/smac/facade/algorithm_configuration_facade.py +++ b/smac/facade/algorithm_configuration_facade.py @@ -99,12 +99,7 @@ def get_acquisition_maximizer( # type: ignore return optimizer @staticmethod - def get_intensifier( - scenario: Scenario, - *, - max_config_calls: int = 2000, - max_incumbents: int = 10, - ) -> Intensifier: + def get_intensifier(scenario: Scenario, *, max_config_calls: int = 
2000, max_incumbents: int = 10) -> Intensifier: """Returns ``Intensifier`` as intensifier. Supports budgets. Parameters @@ -115,11 +110,7 @@ def get_intensifier( max_incumbents : int, defaults to 10 How many incumbents to keep track of in the case of multi-objective. """ - return Intensifier( - scenario=scenario, - max_config_calls=max_config_calls, - max_incumbents=max_incumbents, - ) + return Intensifier(scenario=scenario, max_config_calls=max_config_calls, max_incumbents=max_incumbents) @staticmethod def get_initial_design( # type: ignore diff --git a/smac/intensifier/abstract_intensifier.py b/smac/intensifier/abstract_intensifier.py index 0eb3f8dae..296d53048 100644 --- a/smac/intensifier/abstract_intensifier.py +++ b/smac/intensifier/abstract_intensifier.py @@ -213,7 +213,7 @@ def uses_seeds(self) -> bool: @abstractmethod def uses_budgets(self) -> bool: """If the intensifier needs to make use of budgets.""" - raise NotImplementedError + return False @property @abstractmethod @@ -221,6 +221,12 @@ def uses_instances(self) -> bool: """If the intensifier needs to make use of instances.""" raise NotImplementedError + @property + @abstractmethod + def uses_cutoffs(self) -> bool: + """If the intensifier needs to make use of cutoffs.""" + raise NotImplementedError + @property def incumbents_changed(self) -> int: """How often the incumbents have changed.""" @@ -530,13 +536,41 @@ def update_incumbents(self, config: Configuration) -> None: # 2) Highest budget: We only want to compare the configs if they are evaluated on the highest budget. # Here we do actually care about the budgets. Please see the ``get_instance_seed_budget_keys`` method from # Successive Halving to get more information. - # Noitce: compare=True only takes effect when subclass implemented it. -- e.g. in SH it - # will remove the budgets from the keys. + # Notice: compare=True only takes effect when subclass implemented it. -- e.g. in SH it will remove the budgets + # from the keys. 
config_isb_comparison_keys = self.get_instance_seed_budget_keys(config, compare=True) # Find the lowest intersection of instance-seed-budget keys for all incumbents. config_incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) # Now we have to check if the new config has been evaluated on the same keys as the incumbents + logger.debug( + f"Validate whether we have overlap in the keys evaluated for the challenger config " + f"{config_isb_comparison_keys} and the incumbent config {config_incumbent_isb_comparison_keys}. If the " + f"challenger is dominated, reject id." + ) + + if not self.uses_budgets and all( + [key in config_incumbent_isb_comparison_keys for key in config_isb_comparison_keys] + ): + logger.debug( + "Check on the currently evaluated instances whether the challenger is dominated by the incumbent" + ) + + # determine challenger costs + challenger_costs = self.runhistory.average_cost(config, config_isb_comparison_keys) + + # check the list of incumbents whether any of the incumbents dominates the current challenger + for inc in incumbents: + # determine incumbent costs + inc_costs = self.runhistory.average_cost(inc, config_isb_comparison_keys) + # check dominance + is_dominated = not np.any(np.array([challenger_costs]) < np.array([inc_costs])) + + # if challenger config is dominated by the incumbent, reject it + if is_dominated: + logger.debug(f"Challenger config {config_hash} is dominated by incumbent {get_config_hash(inc)}.") + self._add_rejected_config(config_id) + if not all([key in config_isb_comparison_keys for key in config_incumbent_isb_comparison_keys]): # We can not tell if the new config is better/worse than the incumbents because it has not been # evaluated on the necessary trials diff --git a/smac/intensifier/intensifier.py b/smac/intensifier/intensifier.py index 676161270..58a0abd90 100644 --- a/smac/intensifier/intensifier.py +++ b/smac/intensifier/intensifier.py @@ -1,12 +1,17 @@ from __future__ 
import annotations -from typing import Any, Iterator +from typing import Any, Iterator, Union +import warnings +from collections import defaultdict + +import numpy as np from ConfigSpace import Configuration +from sphinx.writers.latex import UnsupportedError from smac.intensifier.abstract_intensifier import AbstractIntensifier from smac.runhistory import TrialInfo -from smac.runhistory.dataclasses import InstanceSeedBudgetKey +from smac.runhistory.dataclasses import InstanceSeedBudgetKey, InstanceSeedKey, TrialKey from smac.scenario import Scenario from smac.utils.configspace import get_config_hash from smac.utils.logging import get_logger @@ -34,6 +39,8 @@ class Intensifier(AbstractIntensifier): Parameters ---------- + scenario : Scenario + The scenario defining the optimization problem. max_config_calls : int, defaults to 3 Maximum number of configuration evaluations. Basically, how many instance-seed keys should be maxed evaluated for a configuration. @@ -72,11 +79,15 @@ def uses_seeds(self) -> bool: # noqa: D102 def uses_budgets(self) -> bool: # noqa: D102 return False + @property + def uses_cutoffs(self) -> bool: + """If the intensifier needs to make use of cutoffs.""" + return self._scenario.runtime_cutoff is not None or self._scenario.adaptive_capping + @property def uses_instances(self) -> bool: # noqa: D102 if self._scenario.instances is None: return False - return True def get_state(self) -> dict[str, Any]: # noqa: D102 @@ -204,7 +215,7 @@ def __iter__(self) -> Iterator[TrialInfo]: if len(trials) > 0: fails = -1 logger.debug( - f"--- Yielding trial {len(individual_incumbent_isb_keys)+1} of " + f"--- Yielding trial {len(individual_incumbent_isb_keys) + 1} of " f"{self._max_config_calls} from incumbent {incumbent_hash}..." ) yield trials[0] @@ -267,6 +278,7 @@ def __iter__(self) -> Iterator[TrialInfo]: # TODO: What to do if there are no incumbent instances? 
(Use-case: call multiple asks) + logger.debug(f"Get next trials for config {config_hash} and budget {N} from keys {isk_keys}") trials = self._get_next_trials(config, N=N, from_keys=isk_keys) logger.debug(f"--- Yielding {len(trials)} trials to evaluate config {config_hash}...") for trial in trials: @@ -372,4 +384,144 @@ def _get_next_trials( for is_key in is_keys: trials.append(TrialInfo(config=config, instance=is_key.instance, seed=is_key.seed)) + if self.uses_cutoffs and bool(trials): + # We need to adapt the budget to the runtime cutoff + cutoffs = [self._get_adaptivecapping_budget(t.config, on_keys=is_keys) for t in trials] + + # convert existing trials to new trials with adapted budget + trials = [ + TrialInfo(config=t.config, instance=t.instance, seed=t.seed, additional_info={"cutoff": b}) + for b, t in zip(cutoffs, trials) + ] + return trials + + def _get_adaptivecapping_budget( + self, + challenger: Configuration, + on_keys: list[InstanceSeedKey], + ) -> float: + """Adaptive capping: Compute cutoff based on cost so far used for incumbent and reduce + cutoff for next run of challenger accordingly. + + Warning: + For concurrent runs, the budget will be determined for a challenger x instance + combination at the moment the challenger is considered for the instance, ignorant of + the runtime cost of the currently running instances of the same configuration. 
+ + !Only applicable if self.run_obj_time + + !Only applicable in single-objective scenarios where the only cost function is run_obj_time + + !runs on incumbent should be superset of the runs performed for the + challenger + + Parameters + ---------- + challenger : Configuration + Configuration which challenges incumbent + + inc_sum_cost: float + Sum of runtimes of all incumbent runs + + Returns + ------- + cutoff: float + Adapted cutoff + """ + # cost used by challenger for going over all its runs + # should be subset of runs of incumbent (not checked for efficiency + # reasons) + + incumbents = self.get_incumbents(sort_by="num_trials") + if len(incumbents) == 0: + return float(self._scenario.runtime_cutoff) if self._scenario.runtime_cutoff is not None else float("inf") + + if len(incumbents) > 1: + warnings.warn("Adaptive capping is only supported for single incumbent scenarios") + + inc_sum_cost_unchecked = self.runhistory.sum_cost( + config=incumbents[0], + normalize=False, + ) + if isinstance(inc_sum_cost_unchecked, list): + raise UnsupportedError() + else: + inc_sum_cost: float = inc_sum_cost_unchecked + + # original logic for get_runs_for_config: + # https://github.com/automl/SMAC3/blob/f1d2aa2ea3b6ad4075550af69e3300f19411a5ea/smac/runhistory/runhistory.py#L772 + if incumbents[0] == challenger: + # fixme not multi-objective ready + # initially when the queue is empty, the incumbent is intensified, then the cutoff + # must be the runtime minus the already spent budget on the incumbent across + # instances. + if self._scenario.runtime_cutoff is not None: + cutoff = self._scenario.runtime_cutoff - inc_sum_cost + else: + cutoff = float("inf") + if cutoff < 0: + warnings.warn(f"Proposed cutoff for the incumbent is negative: {cutoff}. " f"Setting cutoff to 0.") + cutoff = 0 + else: + # get all runs of the challenger + chall_inst_seeds = self.runhistory.get_instance_seed_budget_keys(challenger) + + # filtered incumbent cost; i.e. 
only the runtime of the subset of those instance the + # challenger will be racing on (before moving to the next subset of instances 2**N). + inc_id = self.runhistory.get_config_id(incumbents[0]) + inc_isb = self.runhistory.get_instance_seed_budget_keys(incumbents[0]) + + # FIXME: on_keys will only have the current instance-seed pair not all of the + # ones of the instance subset, the current challenger is allowed to run on + combined_list: list[Union[InstanceSeedKey, InstanceSeedBudgetKey]] = [*on_keys, *chall_inst_seeds] + current_inc_isb = [ + key for key in inc_isb if any(k.instance == key.instance and k.seed == key.seed for k in combined_list) + ] + + # Instance grouped costs over seeds + instance_costs = defaultdict(list) + for key in current_inc_isb: + k = TrialKey(config_id=inc_id, instance=key.instance, seed=key.seed, budget=key.budget) + instance_costs[key.instance].append(self.runhistory._data[k].cost) + + if any(len(costs) > 1 for costs in instance_costs.values()): + warnings.warn( + "The incumbent has been seen on multiple seeds per instance. " + "For adaptive capping, the cost will be calculated by the average cost " + "over seeds. Should the challenger be evaluated on multiple seeds, " + "we would need to imagine it hadn't for calculating the used budget so far!" + "This is not supported yet." + ) + # Calculate mean cost for each instance across seeds and sum them up to determine the total incumbent cost + inc_sum_cost = 0 + for costs in instance_costs.values(): + if not isinstance(costs[0], list): + average_instance_cost = np.array(costs).mean() + inc_sum_cost += average_instance_cost + else: + raise UnsupportedError() + + # compute the already used runtime for the challenger across instances + # FIXME: in the case of multiple seeds per instance, we need to imagine + # that the instance we want to allocate budget for is not yet evaluated on other + # seeds! 
+ chal_sum_cost = self.runhistory.sum_cost( + config=challenger, + # fixme: chall_inst_seeds needs to be List[InstanceSeedBudgetKey] + instance_seed_budget_keys=chall_inst_seeds, + ) + assert type(chal_sum_cost) == float + + if self._scenario.runtime_cutoff is not None: + if self._scenario.adaptive_capping_slackfactor is not None: + cutoff = min( + self._scenario.runtime_cutoff, + inc_sum_cost * self._scenario.adaptive_capping_slackfactor - chal_sum_cost, + ) + else: + cutoff = min(self._scenario.runtime_cutoff, inc_sum_cost - chal_sum_cost) + else: + cutoff = inc_sum_cost - chal_sum_cost + + return cutoff diff --git a/smac/intensifier/successive_halving.py b/smac/intensifier/successive_halving.py index 546a27377..e7f2e8af5 100644 --- a/smac/intensifier/successive_halving.py +++ b/smac/intensifier/successive_halving.py @@ -248,6 +248,11 @@ def uses_instances(self) -> bool: # noqa: D102 return True + @property + def uses_cutoffs(self) -> bool: + """If the intensifier needs to make use of cutoffs.""" + return False + def print_tracker(self) -> None: """Prints the number of configurations in each bracket/stage.""" messages = [] diff --git a/smac/runhistory/dataclasses.py b/smac/runhistory/dataclasses.py index 8b2b34372..78f5e1fc0 100644 --- a/smac/runhistory/dataclasses.py +++ b/smac/runhistory/dataclasses.py @@ -125,12 +125,14 @@ class TrialInfo: instance : str | None, defaults to None seed : int | None, defaults to None budget : float | None, defaults to None + additional_info: dict[str, Any], defaults to {} """ config: Configuration instance: str | None = None seed: int | None = None budget: float | None = None + additional_info: dict[str, Any] = field(default_factory=dict) def get_instance_seed_key(self) -> InstanceSeedKey: """Instantiates and returns an InstanceSeedKey object""" diff --git a/smac/runhistory/encoder/abstract_encoder.py b/smac/runhistory/encoder/abstract_encoder.py index 94a5f4bf6..0c5cdfb76 100644 --- a/smac/runhistory/encoder/abstract_encoder.py 
+++ b/smac/runhistory/encoder/abstract_encoder.py @@ -72,7 +72,7 @@ def __init__( if self._instances is not None and self._n_features == 0: logger.warning( - "We strongly encourage to use instance features when using instances.", + "We strongly encourage to use instance features when using instances." "If no instance features are passed, the runhistory encoder can not distinguish between different " "instances and therefore returns the same data points with different values, all of which are " "used to train the surrogate model.\n" diff --git a/smac/runner/abstract_runner.py b/smac/runner/abstract_runner.py index 1059b5764..a3a2616ca 100644 --- a/smac/runner/abstract_runner.py +++ b/smac/runner/abstract_runner.py @@ -112,6 +112,7 @@ def run_wrapper( instance=trial_info.instance, budget=trial_info.budget, seed=trial_info.seed, + additional_info=trial_info.additional_info, **dask_data_to_scatter, ) except Exception as e: @@ -187,7 +188,8 @@ def run( instance: str | None = None, budget: float | None = None, seed: int | None = None, - ) -> tuple[StatusType, float | list[float], float, float, dict]: + additional_info: dict[str, Any] | None = None, + ) -> tuple[StatusType, float | list[float], float, float, dict[str, Any]]: # noqa: D102 """Runs the target function with a configuration on a single instance-budget-seed combination (aka trial). @@ -201,6 +203,9 @@ def run( A positive, real-valued number representing an arbitrary limit to the target function handled by the target function internally. seed : int, defaults to None + Seed for the random number generator. + additional_info : dict + Further additional trial information. Returns ------- @@ -212,8 +217,8 @@ def run( The time the target function took to run. cpu_time : float The time the target function took on hardware to run. - additional_info : dict - All further additional trial information. + val_additional_info : dict + All further additional trial value information. 
""" raise NotImplementedError diff --git a/smac/runner/dask_runner.py b/smac/runner/dask_runner.py index 904eb5942..438b20d25 100644 --- a/smac/runner/dask_runner.py +++ b/smac/runner/dask_runner.py @@ -162,10 +162,16 @@ def run( instance: str | None = None, budget: float | None = None, seed: int | None = None, + additional_info: dict[str, Any] | None = None, **dask_data_to_scatter: dict[str, Any], - ) -> tuple[StatusType, float | list[float], float, float, dict]: # noqa: D102 + ) -> tuple[StatusType, float | list[float], float, float, dict[str, Any]]: # noqa: D102 return self._single_worker.run( - config=config, instance=instance, seed=seed, budget=budget, **dask_data_to_scatter + config=config, + instance=instance, + seed=seed, + budget=budget, + additional_info=additional_info, + **dask_data_to_scatter, ) def count_available_workers(self) -> int: diff --git a/smac/runner/target_function_runner.py b/smac/runner/target_function_runner.py index 35dbc2d1b..c26924799 100644 --- a/smac/runner/target_function_runner.py +++ b/smac/runner/target_function_runner.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any, Callable, Optional, Union import copy import inspect @@ -108,11 +108,12 @@ def meta(self) -> dict[str, Any]: # noqa: D102 def run( self, config: Configuration, - instance: str | None = None, - budget: float | None = None, - seed: int | None = None, + instance: Optional[str] = None, + budget: Optional[float] = None, + seed: Optional[int] = None, + additional_info: Optional[dict[str, Any]] = None, **dask_data_to_scatter: dict[str, Any], - ) -> tuple[StatusType, float | list[float], float, float, dict]: + ) -> tuple[StatusType, Union[float, list[float]], float, float, dict[str, Any]]: """Calls the target function with pynisher if algorithm wall time limit or memory limit is set. Otherwise, the function is called directly. 
@@ -126,6 +127,8 @@ def run( A positive, real-valued number representing an arbitrary limit to the target function handled by the target function internally. seed : int, defaults to None + additional_info : dict[str, Any] | None, defaults to None + Additional information to be passed to the target function. dask_data_to_scatter: dict[str, Any] This kwargs must be empty when we do not use dask! () When a user scatters data from their local process to the distributed network, @@ -145,7 +148,7 @@ def run( The time the target function took to run. cpu_time : float The time the target function took on the hardware to run. - additional_info : dict + val_additional_info : dict All further additional trial information. """ # The kwargs are passed to the target function. @@ -161,11 +164,14 @@ def run( if "budget" in self._required_arguments: kwargs["budget"] = budget + if "cutoff" in self._required_arguments and additional_info is not None: + kwargs["cutoff"] = additional_info["cutoff"] + # Presetting cost: float | list[float] = self._crash_cost runtime = 0.0 cpu_time = runtime - additional_info = {} + val_additional_info = {} status = StatusType.CRASHED # If memory limit or walltime limit is set, we wanna use pynisher @@ -197,19 +203,19 @@ def run( status = StatusType.MEMORYOUT except Exception as e: cost = np.asarray(cost).squeeze().tolist() - additional_info = { + val_additional_info = { "traceback": traceback.format_exc(), "error": repr(e), } status = StatusType.CRASHED if status != StatusType.SUCCESS: - return status, cost, runtime, cpu_time, additional_info + return status, cost, runtime, cpu_time, val_additional_info if isinstance(rval, tuple): - result, additional_info = rval + result, val_additional_info = rval else: - result, additional_info = rval, {} + result, val_additional_info = rval, {} # Do some sanity checking (for multi objective) error = f"Returned costs {result} does not match the number of objectives {self._objectives}." 
@@ -245,7 +251,7 @@ def run( # We want to get either a float or a list of floats. cost = np.asarray(cost).squeeze().tolist() - return status, cost, runtime, cpu_time, additional_info + return status, cost, runtime, cpu_time, val_additional_info def __call__( self, diff --git a/smac/runner/target_function_script_runner.py b/smac/runner/target_function_script_runner.py index 7eb8b5f62..d395b4f86 100644 --- a/smac/runner/target_function_script_runner.py +++ b/smac/runner/target_function_script_runner.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Optional, Union import time from subprocess import PIPE, Popen @@ -80,10 +80,11 @@ def meta(self) -> dict[str, Any]: # noqa: D102 def run( self, config: Configuration, - instance: str | None = None, - budget: float | None = None, - seed: int | None = None, - ) -> tuple[StatusType, float | list[float], float, float, dict]: + instance: Optional[str] = None, + budget: Optional[float] = None, + seed: Optional[int] = None, + additional_info: Optional[dict[str, Any]] = None, + ) -> tuple[StatusType, Union[float, list[float]], float, float, dict[str, Any]]: """Calls the target function. Parameters @@ -96,6 +97,8 @@ def run( A positive, real-valued number representing an arbitrary limit to the target function handled by the target function internally. seed : int, defaults to None + additional_info : dict + All further additional trial information. Returns ------- @@ -107,7 +110,7 @@ def run( The time the target function took to run. cpu_time : float The time the target function took on the hardware to run. - additional_info : dict + val_additional_info : dict All further additional trial information. """ # The kwargs are passed to the target function. 
@@ -127,11 +130,14 @@ def run( if "budget" in self._required_arguments: kwargs["budget"] = budget + if "cutoff" in self._required_arguments and additional_info is not None: + kwargs["cutoff"] = additional_info["cutoff"] + # Presetting cost: float | list[float] = self._crash_cost runtime = 0.0 cpu_time = runtime - additional_info = {} + val_additional_info = {} status = StatusType.SUCCESS # Add config arguments to the kwargs @@ -192,10 +198,10 @@ def run( # Add additional info if "additional_info" in outputs: - additional_info["additional_info"] = outputs["additional_info"] + val_additional_info["additional_info"] = outputs["additional_info"] if status != StatusType.SUCCESS: - additional_info["error"] = error + val_additional_info["error"] = error if cost != self._crash_cost: cost = self._crash_cost @@ -203,7 +209,7 @@ def run( "The target function crashed but returned a cost. The cost is ignored and replaced by crash cost." ) - return status, cost, runtime, cpu_time, additional_info + return status, cost, runtime, cpu_time, val_additional_info def __call__( self, diff --git a/smac/scenario.py b/smac/scenario.py index ca0df81a2..716f787ce 100644 --- a/smac/scenario.py +++ b/smac/scenario.py @@ -68,6 +68,13 @@ class Scenario: instance_features : dict[str, list[float]] | None, defaults to None Instances can be associated with features. For example, meta data of the dataset (mean, var, ...) can be incorporated which are then further used to expand the training data of the surrogate model. + adaptive_capping : bool, defaults to False + Adaptive capping allows preemptively cancelling the evaluation of candidates as soon as their accumulated cost + exceeds the cost of the current incumbent. If a runtime_cutoff is set, SMAC will allocate budgets that are capped + at the runtime_cutoff. + runtime_cutoff : int | None, defaults to None + A runtime cutoff can be set to limit the maximum runtime that a single configuration is allowed to spend on solving + instances. 
min_budget : float | int | None, defaults to None The minimum budget (epochs, subset size, number of instances, ...) that is used for the optimization. Use this argument if you use multi-fidelity or instance optimization. @@ -103,6 +110,9 @@ class Scenario: # Algorithm Configuration instances: list[str] | None = None instance_features: dict[str, list[float]] | None = None + adaptive_capping: bool = False + adaptive_capping_slackfactor: float | None = None + runtime_cutoff: int | None = None # Budgets min_budget: float | int | None = None @@ -129,6 +139,10 @@ def __post_init__(self) -> None: instance_features = {str(instance): features for instance, features in self.instance_features.items()} object.__setattr__(self, "instance_features", instance_features) + # Validate that we have a runtime cutoff set if adaptive capping slackfactor is given + if self.adaptive_capping_slackfactor is not None and self.runtime_cutoff is None: + raise ValueError("If adaptive_capping_slackfactor is set, then runtime_cutoff must be set as well.") + # Change directory wrt name and seed self._change_output_directory() diff --git a/tests/test_intensifier/test_adaptive_capping_intensifier.py b/tests/test_intensifier/test_adaptive_capping_intensifier.py new file mode 100644 index 000000000..28bd0be88 --- /dev/null +++ b/tests/test_intensifier/test_adaptive_capping_intensifier.py @@ -0,0 +1,136 @@ +import logging +import shutil +from logging import Logger + +from ConfigSpace import ConfigurationSpace, Configuration + +from smac import Scenario, AlgorithmConfigurationFacade +import numpy as np + +from smac.main.exceptions import ConfigurationSpaceExhaustedException +from smac.utils.configspace import get_config_hash +from smac.utils.logging import get_logger + +__copyright__ = "Copyright 2025, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + +class TrainMockup: + def __init__(self): + # counter for referring to evaluation number + self.log_counter = 0 + + # log list 
and map for easier access of performance data + self.log_list = [] + self.log_map = {} + + # current incumbent configuration + hash + self.incumbent = None + + # trace of incumbents over time + self.incumbent_trace = [] + + # accept/reject event log + self.event_log = [] + + # list of rejected challengers + self.rejected_challengers = [] + + # flag validating whether expected behavior got violated + self.expected_behavior_violated = False + + # list of explanations what erroneous behavior got observed + self.expected_behavior_violations = [] + np.random.seed(42) + + def train(self, config:Configuration, instance: str, seed: int = 0): + self.log_counter += 1 + config_hash = get_config_hash(config) + rand_perf = np.random.random_integers(low=1, high=20) + + # check whether config needed to be rejected already + if config_hash in self.rejected_challengers: + self.expected_behavior_violated = True + + # search for first rejection in event log + reject_log = None + for log in self.event_log: + if log[1] == "reject" and log[2] == config_hash: + reject_log = log + break + + self.expected_behavior_violations += [f"{self.log_counter}: Configuration {config_hash} was already " + f"rejected here {reject_log}"] + + # specify log entry (id, config hash, instance, performance) + log = (self.log_counter, config_hash, instance, rand_perf) + + # ensure config has an entry in the log map and store observed performance, add log entry to list + if config_hash not in self.log_map: + self.log_map[config_hash] = {} + self.log_map[config_hash][instance] = rand_perf + self.log_list += [log] + logger.debug(f"Train: {log}") + + # if incumbent is none so far, we have a new incumbent, i.e., the initial one + if self.incumbent is None: + self.incumbent = config_hash + logger.debug(f"Incumbent initially set to config {self.incumbent}") + else: + instances_evaluated = self.log_map[config_hash].keys() + n_challenger = len(instances_evaluated) + n_incumbent = len(self.log_map[self.incumbent]) + 
c_challenger = np.array([self.log_map[config_hash][instance] for instance in instances_evaluated]).sum() + c_incumbent = np.array([self.log_map[self.incumbent][instance] for instance in instances_evaluated]).sum() + + logger.debug(f"evaluated challenger {config_hash} on budget {n_challenger} with performance " + f"{c_challenger} and incumbent {self.incumbent} was evaluated on budget {n_incumbent} " + f"showing performance {c_incumbent}.") + + if n_challenger >= n_incumbent and c_challenger < c_incumbent: + self.event_log += [(self.log_counter, "accept", config_hash)] + self.incumbent = config_hash + elif c_challenger > c_incumbent: + self.rejected_challengers += [config_hash] + self.event_log += [(self.log_counter, "reject", config_hash)] + + return rand_perf + + def get_violation_report(self): + report = "The following violations occurred:\n" + report += "\n".join(self.expected_behavior_violations) + return report + +def get_basic_setup(train, num_configs = 10, num_instances = 10, num_trials=30): + # generate config space with num_configs many different configurations + cs = ConfigurationSpace({"p1": ["v"+str(i) for i in range(num_configs)], }) + cs.seed(42) + # generate instance set with num_instances many instances + instances = ["i"+str(i) for i in range(num_instances)] + # setup scenario with generated config space, instances, and the given number of trials + scenario = Scenario(cs, deterministic=True, n_trials=num_trials, instances=instances, seed=44) + return AlgorithmConfigurationFacade(scenario, train) + +def test_incumbent_switch(): + # remove smac3 output folder to ensure proper execution of the test + shutil.rmtree("./smac3_output", ignore_errors=True) + shutil.rmtree("./smac3_output_test", ignore_errors=True) + + # setup test environment + tm = TrainMockup() + smac = get_basic_setup(tm.train) + + # activate logging + l: Logger = get_logger("smac.intensifier.abstract_intensifier") + l.setLevel(5) + l: Logger = get_logger("smac.intensifier.intensifier") 
+ l.setLevel(10) + + # start smac run + try: + smac.optimize() + except ConfigurationSpaceExhaustedException: + pass + + assert tm.expected_behavior_violated is False, tm.get_violation_report() \ No newline at end of file diff --git a/tests/test_intensifier/test_intensifier.py b/tests/test_intensifier/test_intensifier.py index aebd96faf..2a5f5a955 100644 --- a/tests/test_intensifier/test_intensifier.py +++ b/tests/test_intensifier/test_intensifier.py @@ -1,6 +1,7 @@ from smac.initial_design.random_design import RandomInitialDesign from smac.intensifier.intensifier import Intensifier from smac.main.config_selector import ConfigSelector +from smac.runhistory import TrialInfo, TrialKey, TrialValue from smac.runhistory.enumerations import StatusType from smac.runhistory.runhistory import RunHistory from smac.scenario import Scenario @@ -73,7 +74,8 @@ def test_next_trials(make_scenario, configspace_small, make_config_selector): # Next, we want to check if evaluated trials are removed config = configspace_small.get_default_configuration() - runhistory.add(config=config, cost=0.5, time=0.0, instance=trials[0].instance, seed=trials[0].seed) + runhistory.add(config=config, cost=0.5, time=0.0, instance=trials[0].instance, + seed=trials[0].seed) trials = intensifier._get_next_trials(config) assert len(trials) == 8 @@ -167,7 +169,8 @@ def test_intensifier(make_scenario, configspace_small, make_config_selector): # Let's mark the first trial as finished # The config should become an incumbent now. - runhistory.add(config=trial.config, cost=10, time=0.0, instance=trial.instance, seed=trial.seed, force_update=True) + runhistory.add(config=trial.config, cost=10, time=0.0, instance=trial.instance, seed=trial.seed, + force_update=True) intensifier.update_incumbents(trial.config) assert intensifier.get_incumbent() == trial.config