Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

probabilistic reparameterization tutorial #1534

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions botorch/acquisition/fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor
from torch.nn import Module


class FixedFeatureAcquisitionFunction(AcquisitionFunction):
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AquisitionFunctions to fix a subset of features.

Example:
Expand Down Expand Up @@ -56,8 +56,7 @@ def __init__(
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
dtype = torch.float
device = torch.device("cpu")
self.d = d
Expand Down Expand Up @@ -126,24 +125,13 @@ def forward(self, X: Tensor):
X_full = self._construct_X_full(X)
return self.acq_func(X_full)

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Optional[Tensor]):
def set_X_pending(self, X_pending: Optional[Tensor]):
r"""Sets the `X_pending` of the base acquisition function."""
if X_pending is not None:
self.acq_func.X_pending = self._construct_X_full(X_pending)
full_X_pending = self._construct_X_full(X_pending)
else:
self.acq_func.X_pending = X_pending
full_X_pending = None
self.acq_func.set_X_pending(full_X_pending)

def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Expand Down
24 changes: 5 additions & 19 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


Expand Down Expand Up @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
return regularization_term


class PenalizedAcquisitionFunction(AcquisitionFunction):
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
r"""Single-outcome acquisition function regularized by the given penalty.

The usage is similar to:
Expand All @@ -161,29 +160,16 @@ def __init__(
penalty_func: The regularization function.
regularization_parameter: Regularization parameter used in optimization.
"""
super().__init__(model=raw_acqf.model)
self.raw_acqf = raw_acqf
AcquisitionFunction.__init__(self, model=raw_acqf.model)
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
self.penalty_func = penalty_func
self.regularization_parameter = regularization_parameter

def forward(self, X: Tensor) -> Tensor:
raw_value = self.raw_acqf(X=X)
raw_value = self.acq_func(X=X)
penalty_term = self.penalty_func(X)
return raw_value - self.regularization_parameter * penalty_term

@property
def X_pending(self) -> Optional[Tensor]:
return self.raw_acqf.X_pending

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
self.raw_acqf.set_X_pending(X_pending=X_pending)
else:
raise UnsupportedError(
"The raw acquisition function is Analytic and does not account "
"for X_pending yet."
)


def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
r"""Computes the group lasso regularization function for the given point.
Expand Down
Loading