Add BoTorch kernel preset, which uses dimensions-scaled prior #483

Status: Draft (40 commits to be merged into base: main)

Commits
09bfc9f
Add BoTorch kernel preset, which uses dimensions-scaled prior
Hrovatin Feb 11, 2025
45c6c49
pre-commit fixes
Hrovatin Feb 12, 2025
ba35bc2
add to changelog
Hrovatin Feb 12, 2025
40c1364
Explicitly close plot objects
Scienfitz Feb 12, 2025
8ff30ed
Extract experimental input validation utility
Scienfitz Jan 3, 2025
c757cdb
Fix simulation with empty initial data
Scienfitz Jan 3, 2025
596ac2f
Expand basic input output tests
Scienfitz Jan 3, 2025
61b1ddc
Add test for invalid pending_experiments
Scienfitz Jan 3, 2025
6d61972
Add pending_experiments validation
Scienfitz Jan 3, 2025
3169539
Fix docstring
Scienfitz Jan 6, 2025
91b3356
Add utility for creating fake input
Scienfitz Jan 6, 2025
0cd53dc
Add fixture for fake measurements
Scienfitz Jan 6, 2025
40ffdd5
Update type hints
Scienfitz Jan 6, 2025
7988f95
Improve text
Scienfitz Jan 6, 2025
1f13d46
Add note
Scienfitz Jan 10, 2025
a4adf50
Add validation everywhere
Scienfitz Feb 7, 2025
a5fcfd1
Avoid duplicated validation
Scienfitz Feb 7, 2025
8ba6ff3
Make wrapper private
Scienfitz Feb 10, 2025
266f6d5
Rework test parameterization
Scienfitz Feb 14, 2025
8ce2784
Improve code
Scienfitz Feb 14, 2025
04bdd99
Remove input validation in meta recommenders
Scienfitz Feb 14, 2025
7ba16b5
Include input validation in BayesianRecommender
Scienfitz Feb 14, 2025
a0e29ed
Reorder docstring
Scienfitz Feb 17, 2025
4ed31a0
Change logger to warnings.warn
Scienfitz Feb 17, 2025
6eda84d
Amend docstring
Scienfitz Feb 17, 2025
5f6f8d1
Remove obsolete parameter
Scienfitz Feb 17, 2025
ac7b287
Remove non-negativity restriction from beta
AdrianSosic Feb 17, 2025
f79cbfb
Rewrite docstring
AdrianSosic Feb 18, 2025
ad39268
Fix validator
Scienfitz Feb 18, 2025
b3d8bdc
Remove obsolete test
Scienfitz Feb 18, 2025
d83e6a2
Add direct arylation benchmark for TL with temperature as a task
Hrovatin Feb 19, 2025
013fec0
Update changelog
Hrovatin Feb 19, 2025
bc038a5
remove random seed that was set in the paper as it is redundant with …
Hrovatin Feb 19, 2025
90f0b49
Benchmark for transfer learning on arylhalides with dissimilar susbst…
Hrovatin Feb 20, 2025
15b35dc
Transfer learning benchmark with inverted Hartmann functions as tasks
Hrovatin Feb 20, 2025
0faec32
Add non-transfer learning campaign and transfer learning campaign wit…
Hrovatin Feb 20, 2025
e6905f9
Transfer learning benchmark with noisy Michalewicz functions as tasks
Hrovatin Feb 21, 2025
77b7d5c
Transfer learning benchmark with noisy Easom functions as tasks.
Hrovatin Feb 21, 2025
ef173ef
add to changelog
Hrovatin Feb 12, 2025
588d73c
Add a few botorch kernel preset benchmarks and adapt scripts for a te…
Hrovatin Feb 21, 2025
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BCUT2D` encoding for `SubstanceParameter`
- Stored benchmarking results now include the Python environment and version
- `qPSTD` acquisition function
- BoTorch kernel presets.
- Additional benchmarks
- BoTorch kernel presets.

### Changed
- Acquisition function indicator `is_mc` has been removed in favor of new indicators
@@ -20,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.12.2] - 2025-01-31
### Changed
- More robust settings for the GP fitting
- The `beta` parameter of `UCB` and `qUCB` can now also take negative values

## [0.12.1] - 2025-01-29
### Changed
28 changes: 12 additions & 16 deletions baybe/acquisition/acqfs.py
@@ -8,16 +8,14 @@
from attr.converters import optional as optional_c
from attr.validators import optional as optional_v
from attrs import define, field, fields
from attrs.validators import ge, gt, instance_of, le
from attrs.validators import gt, instance_of, le
from typing_extensions import override

from baybe.acquisition.base import AcquisitionFunction
from baybe.searchspace import SearchSpace
from baybe.utils.basic import classproperty
from baybe.utils.sampling_algorithms import (
DiscreteSamplingMethod,
sample_numerical_df,
)
from baybe.utils.sampling_algorithms import DiscreteSamplingMethod, sample_numerical_df
from baybe.utils.validation import finite_float


########################################################################################
@@ -264,12 +262,15 @@ class UpperConfidenceBound(AcquisitionFunction):

abbreviation: ClassVar[str] = "UCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
beta: float = field(converter=float, validator=finite_float, default=0.2)
"""Trade-off parameter for mean and variance.

A value of zero makes the acquisition mechanism consider the posterior predictive
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
* ``beta > 0``: Rewards uncertainty, takes more risk.
Limit ``inf``: Pure exploration
* ``beta < 0``: Punishes uncertainty, takes less risk.
Limit ``-inf``: Pure exploitation
* ``beta = 0``: Discards knowledge about uncertainty, i.e. neither rewards nor
punishes it, is risk-neutral.
"""


@@ -279,13 +280,8 @@ class qUpperConfidenceBound(AcquisitionFunction):

abbreviation: ClassVar[str] = "qUCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.

A value of zero makes the acquisition mechanism consider the posterior predictive
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
"""
beta: float = field(converter=float, validator=finite_float, default=0.2)
"""See :paramref:`UpperConfidenceBound.beta`."""


@define(frozen=True)
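For illustration, the effect of the relaxed `beta` range can be seen in a small standalone sketch. It assumes the common `mean + beta * std` form of the UCB score and uses made-up candidate statistics; it is not BayBE code.

```python
# Conceptual sketch of the UCB trade-off, assuming the score "mean + beta * std".
# The candidate statistics are invented purely for illustration.

candidates = {
    "safe": {"mean": 1.0, "std": 0.1},       # well-explored region
    "uncertain": {"mean": 0.9, "std": 1.0},  # poorly-explored region
}


def ucb(mean: float, std: float, beta: float) -> float:
    """UCB score: positive beta rewards uncertainty, negative beta penalizes it."""
    return mean + beta * std


for beta in (2.0, 0.0, -2.0):
    ranked = sorted(candidates, key=lambda c: ucb(**candidates[c], beta=beta), reverse=True)
    print(f"beta={beta:+.1f} -> preferred candidate: {ranked[0]}")

# beta=+2.0 prefers "uncertain" (exploration), beta=0.0 ranks by mean only,
# and beta=-2.0 prefers "safe" (risk-averse exploitation).
```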
58 changes: 16 additions & 42 deletions baybe/campaign.py
@@ -42,8 +42,9 @@
)
from baybe.utils.basic import UNSPECIFIED, UnspecifiedType, is_all_instance
from baybe.utils.boolean import eq_dataframe
from baybe.utils.dataframe import filter_df, fuzzy_row_match
from baybe.utils.dataframe import _ValidatedDataFrame, filter_df, fuzzy_row_match
from baybe.utils.plotting import to_string
from baybe.utils.validation import validate_parameter_input, validate_target_input

if TYPE_CHECKING:
from botorch.posteriors import Posterior
@@ -264,48 +265,25 @@ def add_measurements(
Each addition of data is considered a new batch. Added results are checked for
validity. Categorical values need to have an exact match. For numerical values,
a campaign flag determines if values that lie outside a specified tolerance
are accepted.
Note that this modifies the provided data in-place.
are accepted. Possible validation exceptions are documented in
:func:`baybe.utils.validation.validate_target_input` and
:func:`baybe.utils.validation.validate_parameter_input`.

Args:
data: The data to be added (with filled values for targets). Preferably
created via :func:`baybe.campaign.Campaign.recommend`.
numerical_measurements_must_be_within_tolerance: Flag indicating if
numerical parameters need to be within their tolerances.

Raises:
ValueError: If one of the targets has missing values or NaNs in the provided
dataframe.
TypeError: If the target has non-numeric entries in the provided dataframe.
"""
# Invalidate recommendation cache first (in case of uncaught exceptions below)
self._cached_recommendation = pd.DataFrame()

# Check if all targets have valid values
for target in self.targets:
if data[target.name].isna().any():
raise ValueError(
f"The target '{target.name}' has missing values or NaNs in the "
f"provided dataframe. Missing target values are not supported."
)
if data[target.name].dtype.kind not in "iufb":
raise TypeError(
f"The target '{target.name}' has non-numeric entries in the "
f"provided dataframe. Non-numeric target values are not supported."
)

# Check if all targets have valid values
for param in self.parameters:
if data[param.name].isna().any():
raise ValueError(
f"The parameter '{param.name}' has missing values or NaNs in the "
f"provided dataframe. Missing parameter values are not supported."
)
if param.is_numerical and (data[param.name].dtype.kind not in "iufb"):
raise TypeError(
f"The numerical parameter '{param.name}' has non-numeric entries in"
f" the provided dataframe."
)
# Validate target and parameter input values
validate_target_input(data, self.targets)
validate_parameter_input(
data, self.parameters, numerical_measurements_must_be_within_tolerance
)
data.__class__ = _ValidatedDataFrame

# Read in measurements and add them to the database
self.n_batches_done += 1
@@ -320,20 +298,14 @@
# Update metadata
if self.searchspace.type in (SearchSpaceType.DISCRETE, SearchSpaceType.HYBRID):
idxs_matched = fuzzy_row_match(
self.searchspace.discrete.exp_rep,
data,
self.parameters,
numerical_measurements_must_be_within_tolerance,
self.searchspace.discrete.exp_rep, data, self.parameters
)
self._searchspace_metadata.loc[idxs_matched, _MEASURED] = True

# Telemetry
telemetry_record_value(TELEM_LABELS["COUNT_ADD_RESULTS"], 1)
telemetry_record_recommended_measurement_percentage(
self._cached_recommendation,
data,
self.parameters,
numerical_measurements_must_be_within_tolerance,
self._cached_recommendation, data, self.parameters
)

def toggle_discrete_candidates( # noqa: DOC501
@@ -423,8 +395,10 @@ def recommend(
)

# Invalidate cached recommendation if pending experiments are provided
if (pending_experiments is not None) and (len(pending_experiments) > 0):
if (pending_experiments is not None) and not pending_experiments.empty:
self._cached_recommendation = pd.DataFrame()
validate_parameter_input(pending_experiments, self.parameters)
pending_experiments.__class__ = _ValidatedDataFrame

# If there are cached recommendations and the batch size of those is equal to
# the previously requested one, we just return those
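The refactor above replaces the inline checks with shared validation utilities and a marker class. Below is a minimal sketch of that validate-once pattern, reconstructed from the removed inline checks; the real implementations live in `baybe.utils.validation` and `baybe.utils.dataframe` and may differ in detail.

```python
# Sketch of the "validate once, then mark" pattern used in the diff.
# The class and function bodies below are assumptions modeled on the removed
# inline checks, not the actual BayBE implementations.
import pandas as pd


class _ValidatedDataFrame(pd.DataFrame):
    """Marker subclass signaling that a dataframe has already been validated."""


def validate_target_input(data: pd.DataFrame, targets) -> None:
    """Raise if any target column contains missing or non-numeric values."""
    for target in targets:
        if data[target.name].isna().any():
            raise ValueError(f"Target '{target.name}' has missing values or NaNs.")
        if data[target.name].dtype.kind not in "iufb":
            raise TypeError(f"Target '{target.name}' has non-numeric entries.")


def add_measurements(data: pd.DataFrame, targets) -> None:
    """Campaign-side entry point: validate, then re-tag the instance in-place."""
    validate_target_input(data, targets)
    data.__class__ = _ValidatedDataFrame  # downstream code can now skip re-validation


def recommend(measurements: pd.DataFrame, targets) -> None:
    """Recommender-side entry point: validate only if not already marked."""
    if not isinstance(measurements, _ValidatedDataFrame):
        validate_target_input(measurements, targets)
        measurements.__class__ = _ValidatedDataFrame
```

Reassigning `__class__` keeps the original object and its data intact while letting later `isinstance` checks detect that validation already happened, which is why the docstring now notes that the provided data is modified in-place.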
7 changes: 7 additions & 0 deletions baybe/kernels/basic.py
@@ -6,6 +6,7 @@
from attrs.converters import optional as optional_c
from attrs.validators import ge, gt, in_, instance_of
from attrs.validators import optional as optional_v
from gpytorch.constraints import Interval
from typing_extensions import override

from baybe.kernels.base import BasicKernel
@@ -180,6 +181,12 @@ class RBFKernel(BasicKernel):
)
"""An optional initial value for the kernel lengthscale."""

# TODO replace with baybe constraint if possible
lengthscale_constraint: Interval | None = field(
default=None, validator=optional_v(instance_of(Interval))
)
"""An optional prior on the kernel lengthscale constraint."""


@define(frozen=True)
class RFFKernel(BasicKernel):
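A brief usage sketch for the new `lengthscale_constraint` field. The bound values are arbitrary and only illustrate passing a GPyTorch constraint; this snippet is not taken from the PR's tests.

```python
# Hypothetical usage of the new field; the numeric bounds are illustrative only.
from gpytorch.constraints import GreaterThan, Interval

from baybe.kernels.basic import RBFKernel
from baybe.priors.basic import GammaPrior

# Lower-bounded lengthscale, mirroring the BoTorch preset further below:
kernel_lb = RBFKernel(lengthscale_constraint=GreaterThan(2.5e-2))

# Two-sided constraint combined with a lengthscale prior:
kernel_box = RBFKernel(
    lengthscale_prior=GammaPrior(3.0, 6.0),
    lengthscale_constraint=Interval(1e-2, 1e2),
)
```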
21 changes: 21 additions & 0 deletions baybe/recommenders/pure/base.py
@@ -15,6 +15,8 @@
from baybe.searchspace.continuous import SubspaceContinuous
from baybe.searchspace.core import SearchSpaceType
from baybe.searchspace.discrete import SubspaceDiscrete
from baybe.utils.dataframe import _ValidatedDataFrame
from baybe.utils.validation import validate_parameter_input, validate_target_input

_DEPRECATION_ERROR_MESSAGE = (
"The attribute '{}' is no longer available for recommenders. "
@@ -96,6 +98,25 @@ def recommend(
measurements: pd.DataFrame | None = None,
pending_experiments: pd.DataFrame | None = None,
) -> pd.DataFrame:
# Validation
if (
measurements is not None
and not isinstance(measurements, _ValidatedDataFrame)
and not measurements.empty
and objective is not None
and searchspace is not None
):
validate_target_input(measurements, objective.targets)
validate_parameter_input(measurements, searchspace.parameters)
measurements.__class__ = _ValidatedDataFrame
if (
pending_experiments is not None
and not isinstance(pending_experiments, _ValidatedDataFrame)
and searchspace is not None
):
validate_parameter_input(pending_experiments, searchspace.parameters)
pending_experiments.__class__ = _ValidatedDataFrame

if searchspace.type is SearchSpaceType.CONTINUOUS:
return self._recommend_continuous(
subspace_continuous=searchspace.continuous, batch_size=batch_size
14 changes: 13 additions & 1 deletion baybe/recommenders/pure/bayesian/base.py
@@ -17,6 +17,8 @@
from baybe.searchspace import SearchSpace
from baybe.surrogates import CustomONNXSurrogate, GaussianProcessSurrogate
from baybe.surrogates.base import IndependentGaussianSurrogate, SurrogateProtocol
from baybe.utils.dataframe import _ValidatedDataFrame
from baybe.utils.validation import validate_parameter_input, validate_target_input


@define
@@ -104,11 +106,21 @@ def recommend(
f"that an objective is specified."
)

if (measurements is None) or (len(measurements) == 0):
# Experimental input validation
if (measurements is None) or measurements.empty:
raise NotImplementedError(
f"Recommenders of type '{BayesianRecommender.__name__}' do not support "
f"empty training data."
)
if not isinstance(measurements, _ValidatedDataFrame):
validate_target_input(measurements, objective.targets)
validate_parameter_input(measurements, searchspace.parameters)
measurements.__class__ = _ValidatedDataFrame
if pending_experiments is not None and not isinstance(
pending_experiments, _ValidatedDataFrame
):
validate_parameter_input(pending_experiments, searchspace.parameters)
pending_experiments.__class__ = _ValidatedDataFrame

if (
isinstance(self._surrogate_model, IndependentGaussianSurrogate)
2 changes: 1 addition & 1 deletion baybe/recommenders/pure/nonpredictive/base.py
@@ -35,7 +35,7 @@ def recommend(
f"experiments from the candidate set, adjust the search space "
f"accordingly."
)
if (measurements is not None) and (len(measurements) != 0):
if (measurements is not None) and not measurements.empty:
warnings.warn(
f"'{self.recommend.__name__}' was called with a non-empty "
f"set of measurements but '{self.__class__.__name__}' does not "
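Since commit 4ed31a0 switches this message from a logger call to `warnings.warn`, callers can now filter or escalate it with the standard warnings machinery. A small sketch of that difference follows; the function name is hypothetical.

```python
# Sketch: a warning, unlike a log record, can be silenced or escalated by the caller.
import warnings


def recommend_sketch() -> None:
    # Hypothetical stand-in for a nonpredictive recommender receiving measurements
    warnings.warn("Measurements were provided but will be ignored.", UserWarning)


with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)  # silence in production scripts
    recommend_sketch()

with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)   # escalate to an exception in tests
    try:
        recommend_sketch()
    except UserWarning as exc:
        print(f"caught: {exc}")
```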
2 changes: 1 addition & 1 deletion baybe/simulation/core.py
@@ -118,7 +118,7 @@ def simulate_experiment(
campaign = deepcopy(campaign)

# Add the initial data
if initial_data is not None:
if (initial_data is not None) and not initial_data.empty:
campaign.add_measurements(initial_data)

# For impute_mode 'ignore', do not recommend space entries that are not
2 changes: 2 additions & 0 deletions baybe/surrogates/gaussian_process/presets/__init__.py
@@ -1,5 +1,6 @@
"""Gaussian process surrogate presets."""

from baybe.surrogates.gaussian_process.presets.botorch import BotorchKernelFactory
from baybe.surrogates.gaussian_process.presets.core import (
GaussianProcessPreset,
make_gp_from_preset,
@@ -10,6 +11,7 @@
__all__ = [
"DefaultKernelFactory",
"EDBOKernelFactory",
"BotorchKernelFactory",
"make_gp_from_preset",
"GaussianProcessPreset",
]
56 changes: 56 additions & 0 deletions baybe/surrogates/gaussian_process/presets/botorch.py
@@ -0,0 +1,56 @@
"""Presets adapted from BoTorch."""

from __future__ import annotations

from math import log, sqrt
from typing import TYPE_CHECKING

from attrs import define
from gpytorch.constraints import GreaterThan
from typing_extensions import override

from baybe.kernels.basic import RBFKernel
from baybe.parameters import TaskParameter
from baybe.priors.basic import LogNormalPrior
from baybe.searchspace import SearchSpace
from baybe.surrogates.gaussian_process.kernel_factory import KernelFactory

if TYPE_CHECKING:
from torch import Tensor

from baybe.kernels.base import Kernel


@define
class BotorchKernelFactory(KernelFactory):
"""A kernel factory for Gaussian process surrogates adapted from BoTorch.

References:
* https://github.com/pytorch/botorch/blob/a018a5ffbcbface6229d6c39f7ac6ef9baf5765e/botorch/models/multitask.py#L220
* https://github.com/pytorch/botorch/blob/a018a5ffbcbface6229d6c39f7ac6ef9baf5765e/botorch/models/utils/gpytorch_modules.py#L100

"""

@override
def __call__(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
ard_num_dims = train_x.shape[-1] - len(
[
param
for param in searchspace.discrete.parameters
if isinstance(param, TaskParameter)
]
)
lengthscale_prior = LogNormalPrior(
loc=sqrt(2) + log(ard_num_dims) * 0.5, scale=sqrt(3)
)

return RBFKernel(
lengthscale_prior=lengthscale_prior,
lengthscale_constraint=GreaterThan(
2.5e-2,
transform=None,
initial_value=lengthscale_prior.to_gpytorch().mode,
),
)
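The "dimensions-scaled" aspect of the preset can be made concrete with a short numeric sketch: the log-normal location `sqrt(2) + 0.5 * log(d)` places the prior median lengthscale at `exp(sqrt(2)) * sqrt(d)`, so the preferred lengthscale grows with the square root of the number of non-task dimensions (the `scale=sqrt(3)` term only widens the spread around that median).

```python
# Median lengthscale implied by the prior LogNormal(sqrt(2) + 0.5 * log(d), sqrt(3)).
# The median of a log-normal distribution is exp(loc), independent of the scale.
from math import exp, log, sqrt

for d in (1, 2, 5, 10, 50, 100):
    loc = sqrt(2) + 0.5 * log(d)
    median_lengthscale = exp(loc)  # equals exp(sqrt(2)) * sqrt(d)
    print(f"d={d:>3}: median lengthscale ~ {median_lengthscale:.2f}")

# d=1 -> ~4.11, d=10 -> ~13.00, d=100 -> ~41.13: longer lengthscales
# (smoother functions) are favored as the input dimension grows.
```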
23 changes: 4 additions & 19 deletions baybe/telemetry.py
@@ -214,7 +214,6 @@ def telemetry_record_recommended_measurement_percentage(
cached_recommendation: pd.DataFrame,
measurements: pd.DataFrame,
parameters: Sequence[Parameter],
numerical_measurements_must_be_within_tolerance: bool,
) -> None:
"""Submit the percentage of added measurements.

@@ -232,31 +231,17 @@
measurements: The measurements which are supposed to be checked against cached
recommendations.
parameters: The list of parameters spanning the entire search space.
numerical_measurements_must_be_within_tolerance: If ``True``, numerical
parameter entries are matched with the reference elements only if there is
a match within the parameter tolerance. If ``False``, the closest match
is considered, irrespective of the distance.
"""
if is_enabled():
if len(cached_recommendation) > 0:
if cached_recommendation.empty:
_submit_scalar_value(TELEM_LABELS["NAKED_INITIAL_MEASUREMENTS"], 1)
else:
recommended_measurements_percentage = (
len(
fuzzy_row_match(
cached_recommendation,
measurements,
parameters,
numerical_measurements_must_be_within_tolerance,
)
)
len(fuzzy_row_match(cached_recommendation, measurements, parameters))
/ len(cached_recommendation)
* 100.0
)
_submit_scalar_value(
TELEM_LABELS["RECOMMENDED_MEASUREMENTS_PERCENTAGE"],
recommended_measurements_percentage,
)
else:
_submit_scalar_value(
TELEM_LABELS["NAKED_INITIAL_MEASUREMENTS"],
1,
)