Pareto optimization #475

Open
wants to merge 47 commits into
base: main
bdd1e90
Improve deprecation warning message
AdrianSosic Dec 2, 2024
5e52039
Draft ParetoObjective class
AdrianSosic Dec 2, 2024
460bbc1
Extract function for transforming target columns
AdrianSosic Dec 2, 2024
04056ee
Add qLogNEHVI acqusition function
AdrianSosic Dec 5, 2024
8b2d199
Make botorch multiobjective acqusition functions autodetectable
AdrianSosic Dec 5, 2024
6788abb
Add temporary restriction allowing only MAX targets
AdrianSosic Dec 5, 2024
1043156
Draft example
AdrianSosic Dec 5, 2024
473ac15
Enable minimization targets
AdrianSosic Jan 24, 2025
41542f1
Add highlighted feature to README
AdrianSosic Jan 24, 2025
d0d7716
Update CHANGELOG.md
AdrianSosic Jan 24, 2025
75be534
Compute default reference point from data
AdrianSosic Feb 3, 2025
0e340e5
Flip signs of custom reference points in MIN mode
AdrianSosic Feb 3, 2025
12eef75
Interpolate target paretor frontier along transformed points
AdrianSosic Feb 3, 2025
0f29bda
Drop unnecessary label arguments
AdrianSosic Feb 3, 2025
1081ce7
Drop square root from target functions
AdrianSosic Feb 3, 2025
262392f
Mention ParetoObjective in README
AdrianSosic Feb 3, 2025
d9d8541
Fix enum comparison operator
AdrianSosic Feb 12, 2025
42273de
Drop duplicate override decorator
AdrianSosic Feb 12, 2025
44b605d
Fix random seed utility import
AdrianSosic Feb 12, 2025
b9bc60c
Explicitly convert targets to objectives
AdrianSosic Feb 12, 2025
7b97131
Dynamically select default acquisition function
AdrianSosic Feb 13, 2025
eb6aad6
Deactivate comparison for non-persistent attributes
AdrianSosic Feb 13, 2025
9cf8363
Fix variable reference in example
AdrianSosic Feb 13, 2025
5795a4a
Turn assert statement into proper exception
AdrianSosic Feb 13, 2025
f4384bd
Add prune_baseline attribute
AdrianSosic Feb 13, 2025
1fd25e9
Add full docstring to compute_ref_point
AdrianSosic Feb 13, 2025
0c605b8
Refactor ref_point computation logic
AdrianSosic Feb 13, 2025
81bf88c
Let doc generation append regular image when available
AdrianSosic Feb 13, 2025
f3e5c57
Add ParetoObjective user guide section
AdrianSosic Feb 13, 2025
3b41310
Add surrogate broadcasting mechanism
AdrianSosic Feb 14, 2025
3deee1b
Validate multi-target compatibility
AdrianSosic Feb 15, 2025
0252ef0
Rename broadcasting.py to composite.py
AdrianSosic Feb 15, 2025
26c7e36
Add CompositeSurrogate class
AdrianSosic Feb 17, 2025
d22bdac
Add surrogate composition test
AdrianSosic Feb 17, 2025
3e867d2
Add TODO note
AdrianSosic Feb 17, 2025
966cfd7
Add qLogNoisyExpectedHypervolumeImprovement to strategy
AdrianSosic Feb 17, 2025
43872a2
Add missing strategy arguments
AdrianSosic Feb 17, 2025
c157d01
Add pareto_objectives strategy and serialization test
AdrianSosic Feb 17, 2025
3383fd9
Fix default acquisition function mechanism
AdrianSosic Feb 17, 2025
5918719
Throw exception when using single-target acqf in multi-target context
AdrianSosic Feb 17, 2025
727ac31
Use specific incompatibility errors instead of generic ValueError
AdrianSosic Feb 17, 2025
d124ffd
Drop opinionated statement from user guide
AdrianSosic Feb 17, 2025
78fee0c
Mention requirement of multi-target acquisition function in user guide
AdrianSosic Feb 17, 2025
b3b77fb
Update CHANGELOG.md
AdrianSosic Feb 17, 2025
dc3d601
Rename supports_multi_target to support_multi_output
AdrianSosic Feb 18, 2025
f8b6b61
Fix Liskov
AdrianSosic Feb 18, 2025
ec9570f
Ignore typing problem in classproperty
AdrianSosic Feb 18, 2025
7 changes: 7 additions & 0 deletions CHANGELOG.md
Collaborator

detached comment 3: when working with Pareto optimization, I think the `Campaign.posterior` call is much more important and probably part of the workflow (to check the target predictions and possibly make a subchoice of points to evaluate on the predicted frontier). Imo this should then be mentioned both in the UG and in the example (albeit briefly, just referencing the posterior method and explaining why it's useful for that)

Collaborator

Followup: for the API as well as less experienced users, it might be useful to have a posterior method or any other convenience object that returns a dataframe with targets, mean and var of posterior prediction, and not a posterior object

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BCUT2D` encoding for `SubstanceParameter`
- Stored benchmarking results now include the Python environment and version
- `qPSTD` acquisition function
- `ParetoObjective` class for Pareto optimization of multiple targets and corresponding
`qLogNoisyExpectedHypervolumeImprovement` acquisition function
- `BroadcastingSurrogate` class and corresponding `Surrogate.broadcast` method for
making single-target surrogate models multi-target compatible
- `CompositeSurrogate` class for composing multi-target surrogates from single-target
surrogates
- `supports_multi_output` attribute/property to `Surrogate`/`AcquisitionFunction`

### Changed
- Acquisition function indicator `is_mc` has been removed in favor of new indicators
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ The following provides a non-comprehensive overview:

- 🛠️ Custom parameter encodings: Improve your campaign with domain knowledge
- 🧪 Built-in chemical encodings: Improve your campaign with chemical knowledge
- 🎯 Single and multiple targets with min, max and match objectives
- 🎯 Numerical and binary targets with min, max and match objectives
- ⚖️ Multi-target support via Pareto optimization and desirability scalarization
- 🔍 Insights: Easily analyze feature importance and model behavior
- 🎭 Hybrid (mixed continuous and discrete) spaces
- 🚀 Transfer learning: Mix data from multiple campaigns and accelerate optimization
Expand Down Expand Up @@ -78,8 +79,8 @@ target = NumericalTarget(
objective = SingleTargetObjective(target=target)
```
In cases where we are confronted with multiple (potentially conflicting) targets,
the `DesirabilityObjective` can be used instead. It allows to define additional
settings, such as how these targets should be balanced.
the `ParetoObjective` or `DesirabilityObjective` can be used instead.
These allow to define additional settings, such as how the targets should be balanced.
For more details, see the
[objectives section](https://emdgroup.github.io/baybe/stable/userguide/objectives.html)
of the user guide.
Expand Down
6 changes: 6 additions & 0 deletions baybe/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
qExpectedImprovement,
qKnowledgeGradient,
qLogExpectedImprovement,
qLogNoisyExpectedHypervolumeImprovement,
qLogNoisyExpectedImprovement,
qNegIntegratedPosteriorVariance,
qNoisyExpectedImprovement,
Expand Down Expand Up @@ -37,6 +38,7 @@
UCB = UpperConfidenceBound
qUCB = qUpperConfidenceBound
qTS = qThompsonSampling
qLogNEHVI = qLogNoisyExpectedHypervolumeImprovement

__all__ = [
######################### Acquisition functions
Expand Down Expand Up @@ -64,6 +66,8 @@
"qUpperConfidenceBound",
# Thompson Sampling
"qThompsonSampling",
# Hypervolume Improvement
"qLogNoisyExpectedHypervolumeImprovement",
######################### Abbreviations
# Knowledge Gradient
"qKG",
Expand All @@ -89,4 +93,6 @@
"qUCB",
# Thompson Sampling
"qTS",
# Hypervolume Improvement
"qLogNEHVI",
]
81 changes: 80 additions & 1 deletion baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import math
from typing import ClassVar

import numpy as np
import numpy.typing as npt
import pandas as pd
from attr.converters import optional as optional_c
from attr.validators import optional as optional_v
Expand All @@ -13,7 +15,7 @@

from baybe.acquisition.base import AcquisitionFunction
from baybe.searchspace import SearchSpace
from baybe.utils.basic import classproperty
from baybe.utils.basic import classproperty, convert_to_float
from baybe.utils.sampling_algorithms import (
DiscreteSamplingMethod,
sample_numerical_df,
Expand Down Expand Up @@ -320,5 +322,82 @@ def supports_batching(cls) -> bool:
return False


########################################################################################
### Hypervolume Improvement
@define(frozen=True)
class qLogNoisyExpectedHypervolumeImprovement(AcquisitionFunction):
"""Logarithmic Monte Carlo based noisy expected hypervolume improvement."""

abbreviation: ClassVar[str] = "qLogNEHVI"

ref_point: float | tuple[float, ...] | None = field(
default=None, converter=optional_c(convert_to_float)
)
"""The reference point for computing the hypervolume improvement.

* When omitted, a default reference point is computed based on the provided data.
* When specified as a float, the value is interpreted as a multiplicative factor
determining the reference point location based on the difference between the best
and worst target configuration in the provided data.
* When specified as a vector, the input is taken as is.
"""

prune_baseline: bool = field(default=True, validator=instance_of(bool))
"""Auto-prune candidates that are unlikely to be the best."""

@staticmethod
def compute_ref_point(
array: npt.ArrayLike, maximize: npt.ArrayLike, factor: float = 0.1
) -> np.ndarray:
"""Compute a reference point for a given set of target configurations.

The reference point is positioned in relation to the worst target configuration
within the provided array. The distance in each target dimension is adjusted by
a specified multiplication factor, which scales the reference point away from
the worst target configuration based on the maximum observed differences in
target values.
Comment on lines +354 to +358
Collaborator

Suggested change
The reference point is positioned in relation to the worst target configuration
within the provided array. The distance in each target dimension is adjusted by
a specified multiplication factor, which scales the reference point away from
the worst target configuration based on the maximum observed differences in
target values.
The reference point is positioned relative to the worst point in the direction coming from the best
point. A factor of 0.0 would result in the reference point being the worst point, while a factor > 0.0
would move the reference point further away from both worst and best points. A factor of 1.0 would
exactly mirror the best point across the worst point.


Example:
>>> from baybe.acquisition import qLogNEHVI

>>> qLogNEHVI.compute_ref_point([[0, 10], [2, 20]], [True, True], 0.1)
array([-0.2, 9. ])

>>> qLogNEHVI.compute_ref_point([[0, 10], [2, 20]], [True, False], 0.2)
array([ -0.4, 22. ])

Args:
array: A 2-D array-like where each row represents a target configuration.
maximize: A 1-D boolean array indicating which targets are to be maximized.
factor: A numeric value controlling the location of the reference point.

Raises:
ValueError: If the given target configuration array is not two-dimensional.
ValueError: If the given Boolean array is not one-dimensional.

Returns:
The computed reference point.
"""
if np.ndim(array) != 2:
raise ValueError(
"The specified data array must have exactly two dimensions."
)
if np.ndim(maximize) != 1:
raise ValueError(
"The specified Boolean array must have exactly one dimension."
)

# Convert arrays
array = np.asarray(array)
maximize = np.where(maximize, 1.0, -1.0)

# Compute bounds
array = array * maximize[None, :]
min = np.min(array, axis=0)
max = np.max(array, axis=0)

return (min - factor * (max - min)) * maximize


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
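
As a cross-check on the two docstring examples in `compute_ref_point`, the reference point rule can be restated in a few lines of plain Python. This is a minimal sketch and not BayBE's implementation: the function name `ref_point_sketch` is hypothetical, and it works on nested lists instead of NumPy arrays.

```python
def ref_point_sketch(rows, maximize, factor=0.1):
    """Hypothetical pure-Python restatement of the reference point rule."""
    # Sign per target: +1 for maximization, -1 for minimization
    signs = [1.0 if m else -1.0 for m in maximize]
    ref = []
    for j, sign in enumerate(signs):
        # Flip minimization targets so that "larger is better" holds everywhere
        column = [row[j] * sign for row in rows]
        worst, best = min(column), max(column)
        # Step from the worst value away from the best one, then flip back
        ref.append((worst - factor * (best - worst)) * sign)
    return ref


# Same inputs as the docstring examples
print(ref_point_sketch([[0, 10], [2, 20]], [True, True], 0.1))
print(ref_point_sketch([[0, 10], [2, 20]], [True, False], 0.2))
```

With `factor=0.0` the sketch returns the worst observed value per target, which matches the reading given in the review suggestion.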
50 changes: 48 additions & 2 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gc
import warnings
from abc import ABC
from collections.abc import Iterable
from inspect import signature
from typing import TYPE_CHECKING, ClassVar

Expand All @@ -17,6 +18,7 @@
)
from baybe.objectives.base import Objective
from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.pareto import ParetoObjective
from baybe.objectives.single import SingleTargetObjective
from baybe.searchspace.core import SearchSpace
from baybe.serialization.core import (
Expand Down Expand Up @@ -56,6 +58,11 @@ def supports_pending_experiments(cls) -> bool:
"""
return cls.supports_batching

@classproperty
def supports_multi_output(cls) -> bool:
"""Flag indicating whether multiple outputs are supported."""
return "Hypervolume" in cls.__name__ # type: ignore[attr-defined]

@classproperty
def _non_botorch_attrs(cls) -> tuple[str, ...]:
"""Names of attributes that are not passed to the BoTorch constructor."""
Expand All @@ -76,9 +83,13 @@ def to_botorch(
"""
import botorch.acquisition as bo_acqf
import torch
from botorch.acquisition.multi_objective import WeightedMCMultiOutputObjective
from botorch.acquisition.objective import LinearMCObjective

from baybe.acquisition.acqfs import qThompsonSampling
from baybe.acquisition.acqfs import (
qLogNoisyExpectedHypervolumeImprovement,
qThompsonSampling,
)

# Retrieve botorch acquisition function class and match attributes
acqf_cls = _get_botorch_acqf_class(type(self))
Expand Down Expand Up @@ -151,6 +162,39 @@ def to_botorch(
additional_params["best_f"] = (
bo_surrogate.posterior(train_x).mean.max().item()
)
case ParetoObjective():
if not isinstance(self, qLogNoisyExpectedHypervolumeImprovement):
raise IncompatibleAcquisitionFunctionError(
f"Pareto optimization currently supports the "
f"'{qLogNoisyExpectedHypervolumeImprovement.__name__}' "
f"acquisition function only."
)
if not all(
isinstance(t, NumericalTarget)
and t.mode in (TargetMode.MAX, TargetMode.MIN)
for t in objective.targets
):
raise NotImplementedError(
"Pareto optimization currently supports "
"maximization/minimization targets only."
)
maximize = [t.mode is TargetMode.MAX for t in objective.targets] # type: ignore[attr-defined]
multiplier = [1.0 if m else -1.0 for m in maximize]
additional_params["objective"] = WeightedMCMultiOutputObjective(
torch.tensor(multiplier)
)
train_y = measurements[[t.name for t in objective.targets]].to_numpy()
if isinstance(ref_point := params_dict["ref_point"], Iterable):
ref_point = [
p * m for p, m in zip(ref_point, multiplier, strict=True)
]
else:
kwargs = {"factor": ref_point} if ref_point is not None else {}
ref_point = (
self.compute_ref_point(train_y, maximize, **kwargs) * multiplier
)
params_dict["ref_point"] = ref_point

case _:
raise ValueError(f"Unsupported objective type: {objective}")

Expand All @@ -172,7 +216,9 @@ def _get_botorch_acqf_class(
import botorch

for cls in baybe_acqf_cls.mro():
if acqf_cls := getattr(botorch.acquisition, cls.__name__, False):
if acqf_cls := getattr(botorch.acquisition, cls.__name__, False) or getattr(
botorch.acquisition.multi_objective, cls.__name__, False
):
if is_abstract(acqf_cls):
continue
return acqf_cls # type: ignore
Expand Down
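
The `ParetoObjective` branch in `to_botorch` turns every target into a maximization problem by weighting it with +1 or -1, and it applies the same weights to a user-supplied reference point. A minimal sketch of that sign convention, using plain lists and the hypothetical helper name `to_max_space` (not part of BayBE or BoTorch):

```python
def to_max_space(values, maximize):
    """Flip the sign of entries belonging to minimization targets (hypothetical)."""
    if len(values) != len(maximize):
        raise ValueError("Expected one value per target.")
    return [v if m else -v for v, m in zip(values, maximize)]


# Assumed setup: the first target is maximized, the second is minimized
maximize = [True, False]
ref_point = [0.0, 25.0]  # reference point in the original target units
print(to_max_space(ref_point, maximize))
```

Because the flip is its own inverse, the same helper maps values from the all-maximization space back to the original units.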
4 changes: 4 additions & 0 deletions baybe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class IncompatibleSearchSpaceError(IncompatibilityError):
"""


class IncompatibleSurrogateError(IncompatibilityError):
"""An incompatible surrogate was selected."""


class IncompatibleAcquisitionFunctionError(IncompatibilityError):
"""An incompatible acquisition function was selected."""

Expand Down
2 changes: 2 additions & 0 deletions baybe/objectives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""BayBE objectives."""

from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.pareto import ParetoObjective
from baybe.objectives.single import SingleTargetObjective

__all__ = [
"SingleTargetObjective",
"DesirabilityObjective",
"ParetoObjective",
]
13 changes: 4 additions & 9 deletions baybe/objectives/desirability.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import is_all_instance, to_tuple
from baybe.utils.dataframe import get_transform_objects, pretty_print_df
from baybe.utils.dataframe import pretty_print_df, transform_target_columns
from baybe.utils.numerical import geom_mean
from baybe.utils.plotting import to_string
from baybe.utils.validation import finite_float
Expand Down Expand Up @@ -145,7 +145,7 @@ def transform(
# >>>>>>>>>> Deprecation
if not ((df is None) ^ (data is None)):
raise ValueError(
"Provide the dataframe to be transformed as argument to `df`."
"Provide the dataframe to be transformed as first positional argument."
)

if data is not None:
Expand All @@ -172,15 +172,10 @@ def transform(
)
# <<<<<<<<<< Deprecation

# Extract the relevant part of the dataframe
targets = get_transform_objects(
# Transform all targets individually
transformed = transform_target_columns(
df, self.targets, allow_missing=allow_missing, allow_extra=allow_extra
)
transformed = df[[t.name for t in targets]].copy()

# Transform all targets individually
for target in self.targets:
transformed[target.name] = target.transform(df[target.name])

# Scalarize the transformed targets into desirability values
vals = scalarize(transformed.values, self.scalarizer, self._normalized_weights)
Expand Down
74 changes: 74 additions & 0 deletions baybe/objectives/pareto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Functionality for multi-target objectives."""

import warnings

import pandas as pd
from attrs import define, field
from attrs.validators import deep_iterable, instance_of, min_len
from typing_extensions import override

from baybe.objectives.base import Objective
from baybe.targets.base import Target
from baybe.utils.basic import to_tuple
from baybe.utils.dataframe import transform_target_columns


@define(frozen=True, slots=False)
class ParetoObjective(Objective):
"""An objective handling multiple targets in a Pareto sense."""

_targets: tuple[Target, ...] = field(
converter=to_tuple,
validator=[min_len(2), deep_iterable(member_validator=instance_of(Target))],
alias="targets",
)
"The targets considered by the objective."

@override
@property
def targets(self) -> tuple[Target, ...]:
return self._targets

@override
def transform(
self,
df: pd.DataFrame | None = None,
/,
*,
allow_missing: bool = False,
allow_extra: bool | None = None,
data: pd.DataFrame | None = None,
) -> pd.DataFrame:
# >>>>>>>>>> Deprecation
if not ((df is None) ^ (data is None)):
raise ValueError(
"Provide the dataframe to be transformed as first positional argument."
)

if data is not None:
df = data
warnings.warn(
"Providing the dataframe via the `data` argument is deprecated and "
"will be removed in a future version. Please pass your dataframe "
"as positional argument instead.",
DeprecationWarning,
)

# Mypy does not infer from the above that `df` must be a dataframe here
assert isinstance(df, pd.DataFrame)

if allow_extra is None:
allow_extra = True
if set(df.columns) - {p.name for p in self.targets}:
warnings.warn(
"For backward compatibility, the new `allow_extra` flag is set "
"to `True` when left unspecified. However, this behavior will be "
"changed in a future version. If you want to invoke the old "
"behavior, please explicitly set `allow_extra=True`.",
DeprecationWarning,
)
# <<<<<<<<<< Deprecation

return transform_target_columns(
df, self.targets, allow_missing=allow_missing, allow_extra=allow_extra
)
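
To illustrate what handling targets "in a Pareto sense" means, the following sketch filters a set of measurements down to its non-dominated points. It is an illustrative assumption and not code from this PR: `pareto_front` is a hypothetical helper, and the quadratic-time loop is chosen for readability rather than speed.

```python
def pareto_front(points, maximize):
    """Return the non-dominated points (hypothetical helper, O(n^2))."""
    signs = [1.0 if m else -1.0 for m in maximize]
    # Work in a space where larger is better for every target
    flipped = [[v * s for v, s in zip(p, signs)] for p in points]

    def dominates(a, b):
        # a dominates b if it is at least as good everywhere and better somewhere
        return all(x >= y for x, y in zip(a, b)) and any(
            x > y for x, y in zip(a, b)
        )

    return [
        p
        for p, fp in zip(points, flipped)
        if not any(dominates(fq, fp) for fq in flipped)
    ]


# Assumed example: maximize the first target, minimize the second
measurements = [[1.0, 5.0], [2.0, 6.0], [3.0, 7.0], [1.5, 6.5]]
print(pareto_front(measurements, [True, False]))
```

In this example the point `[1.5, 6.5]` is dropped because `[2.0, 6.0]` is better in both targets, while the remaining three points trade one target off against the other.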