Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Negative beta values #488

Merged
merged 4 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.12.2] - 2025-01-31
### Changed
- More robust settings for the GP fitting
- The `beta` parameter of `UCB` and `qUCB` can now also take negative values

## [0.12.1] - 2025-01-29
### Changed
Expand Down
28 changes: 12 additions & 16 deletions baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
from attr.converters import optional as optional_c
from attr.validators import optional as optional_v
from attrs import define, field, fields
from attrs.validators import ge, gt, instance_of, le
from attrs.validators import gt, instance_of, le
from typing_extensions import override

from baybe.acquisition.base import AcquisitionFunction
from baybe.searchspace import SearchSpace
from baybe.utils.basic import classproperty
from baybe.utils.sampling_algorithms import (
DiscreteSamplingMethod,
sample_numerical_df,
)
from baybe.utils.sampling_algorithms import DiscreteSamplingMethod, sample_numerical_df
from baybe.utils.validation import finite_float


########################################################################################
Expand Down Expand Up @@ -264,12 +262,15 @@ class UpperConfidenceBound(AcquisitionFunction):

abbreviation: ClassVar[str] = "UCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
beta: float = field(converter=float, validator=finite_float, default=0.2)
"""Trade-off parameter for mean and variance.
A value of zero makes the acquisition mechanism consider the posterior predictive
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
* ``beta > 0``: Rewards uncertainty, takes more risk.
Limit ``inf``: Pure exploration
* ``beta < 0``: Punishes uncertainty, takes less risk.
Limit ``-inf``: Pure exploitation
* ``beta = 0``: Discards knowledge about uncertainty, i.e. neither rewards nor
punishes it, is risk-neutral.
"""


Expand All @@ -279,13 +280,8 @@ class qUpperConfidenceBound(AcquisitionFunction):

abbreviation: ClassVar[str] = "qUCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.
A value of zero makes the acquisition mechanism consider the posterior predictive
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
"""
beta: float = field(converter=float, validator=finite_float, default=0.2)
"""See :paramref:`UpperConfidenceBound.beta`."""


@define(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/hypothesis_strategies/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def _qNIPV_strategy(draw: st.DrawFn):
acquisition_functions = st.one_of(
st.builds(ExpectedImprovement),
st.builds(ProbabilityOfImprovement),
st.builds(UpperConfidenceBound, beta=finite_floats(min_value=0.0)),
st.builds(UpperConfidenceBound, beta=finite_floats()),
st.builds(PosteriorMean),
st.builds(PosteriorStandardDeviation, maximize=st.sampled_from([True, False])),
st.builds(qPosteriorStandardDeviation),
st.builds(LogExpectedImprovement),
st.builds(qExpectedImprovement),
st.builds(qProbabilityOfImprovement),
st.builds(qUpperConfidenceBound, beta=finite_floats(min_value=0.0)),
st.builds(qUpperConfidenceBound, beta=finite_floats()),
st.builds(qSimpleRegret),
st.builds(qLogExpectedImprovement),
st.builds(
Expand Down
5 changes: 2 additions & 3 deletions tests/validation/test_acqf_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,17 @@ def test_qNIPV(sampling_n_points, sampling_fraction, sampling_method, error, mat
qNIPV(**kwargs)


@pytest.mark.parametrize("acqf", [qNEI, qLogNEI])
@pytest.mark.parametrize("acqf", [qNEI, qLogNEI], ids=["qNEI", "qLogNEI"])
def test_EI(acqf):
"""Providing a non-Boolean argument to `prune_baseline` raises an error."""
with pytest.raises(TypeError):
acqf(1)


@pytest.mark.parametrize("acqf", [UCB, qUCB])
@pytest.mark.parametrize("acqf", [UCB, qUCB], ids=["UCB", "qUCB"])
@pytest.mark.parametrize(
("beta", "match"),
[
param(-1.0, "must be >= 0.0", id="negative"),
param("a", "could not convert", id="not_a_float"),
],
)
Expand Down