Skip to content

Commit

Permalink
Merge: Negative beta values (#488)
Browse files Browse the repository at this point in the history
This PR removes the unnecessary restriction that `beta` in `(q)UCB` can
only take non-negative values + refines the corresponding docstring.
  • Loading branch information
Scienfitz authored Feb 18, 2025
2 parents 949b897 + e2d89dd commit 613bd3a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.12.2] - 2025-01-31
### Changed
- More robust settings for the GP fitting
- The `beta` parameter of `UCB` and `qUCB` can now also take negative values

## [0.12.1] - 2025-01-29
### Changed
Expand Down
28 changes: 12 additions & 16 deletions baybe/acquisition/acqfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
from attr.converters import optional as optional_c
from attr.validators import optional as optional_v
from attrs import define, field, fields
from attrs.validators import ge, gt, instance_of, le
from attrs.validators import gt, instance_of, le
from typing_extensions import override

from baybe.acquisition.base import AcquisitionFunction
from baybe.searchspace import SearchSpace
from baybe.utils.basic import classproperty
from baybe.utils.sampling_algorithms import (
DiscreteSamplingMethod,
sample_numerical_df,
)
from baybe.utils.sampling_algorithms import DiscreteSamplingMethod, sample_numerical_df
from baybe.utils.validation import finite_float


########################################################################################
Expand Down Expand Up @@ -264,12 +262,15 @@ class UpperConfidenceBound(AcquisitionFunction):

abbreviation: ClassVar[str] = "UCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
beta: float = field(converter=float, validator=finite_float, default=0.2)
"""Trade-off parameter for mean and variance.
A value of zero makes the acquisition mechanism consider the posterior predictive
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
* ``beta > 0``: Rewards uncertainty, takes more risk.
Limit ``inf``: Pure exploration
* ``beta < 0``: Punishes uncertainty, takes less risk.
Limit ``-inf``: Pure exploitation
* ``beta = 0``: Discards knowledge about uncertainty, i.e. neither rewards nor
punishes it, is risk-neutral.
"""


Expand All @@ -279,13 +280,8 @@ class qUpperConfidenceBound(AcquisitionFunction):

abbreviation: ClassVar[str] = "qUCB"

beta: float = field(converter=float, validator=ge(0.0), default=0.2)
"""Trade-off parameter for mean and variance.
A value of zero makes the acquisition mechanism consider the posterior predictive
mean only, resulting in pure exploitation. Higher values shift the focus more and
more toward exploration.
"""
beta: float = field(converter=float, validator=finite_float, default=0.2)
"""See :paramref:`UpperConfidenceBound.beta`."""


@define(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/hypothesis_strategies/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def _qNIPV_strategy(draw: st.DrawFn):
acquisition_functions = st.one_of(
st.builds(ExpectedImprovement),
st.builds(ProbabilityOfImprovement),
st.builds(UpperConfidenceBound, beta=finite_floats(min_value=0.0)),
st.builds(UpperConfidenceBound, beta=finite_floats()),
st.builds(PosteriorMean),
st.builds(PosteriorStandardDeviation, maximize=st.sampled_from([True, False])),
st.builds(qPosteriorStandardDeviation),
st.builds(LogExpectedImprovement),
st.builds(qExpectedImprovement),
st.builds(qProbabilityOfImprovement),
st.builds(qUpperConfidenceBound, beta=finite_floats(min_value=0.0)),
st.builds(qUpperConfidenceBound, beta=finite_floats()),
st.builds(qSimpleRegret),
st.builds(qLogExpectedImprovement),
st.builds(
Expand Down
5 changes: 2 additions & 3 deletions tests/validation/test_acqf_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,17 @@ def test_qNIPV(sampling_n_points, sampling_fraction, sampling_method, error, mat
qNIPV(**kwargs)


@pytest.mark.parametrize("acqf", [qNEI, qLogNEI])
@pytest.mark.parametrize("acqf", [qNEI, qLogNEI], ids=["qNEI", "qLogNEI"])
def test_EI(acqf):
"""Providing a non-Boolean argument to `prune_baseline` raises an error."""
with pytest.raises(TypeError):
acqf(1)


@pytest.mark.parametrize("acqf", [UCB, qUCB])
@pytest.mark.parametrize("acqf", [UCB, qUCB], ids=["UCB", "qUCB"])
@pytest.mark.parametrize(
("beta", "match"),
[
param(-1.0, "must be >= 0.0", id="negative"),
param("a", "could not convert", id="not_a_float"),
],
)
Expand Down

0 comments on commit 613bd3a

Please sign in to comment.