Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Move input transforms to GPyTorch #1372

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
variational_distribution: Optional[_VariationalDistribution] = None,
variational_strategy: Type[_VariationalStrategy] = VariationalStrategy,
inducing_points: Optional[Union[Tensor, int]] = None,
input_transform: Optional[InputTransform] = None,
) -> None:
r"""
Args:
Expand Down Expand Up @@ -252,6 +253,8 @@ def __init__(
super().__init__(variational_strategy=variational_strategy)
self.mean_module = mean_module
self.covar_module = covar_module
if input_transform is not None:
self.input_transform = input_transform

def forward(self, X) -> MultivariateNormal:
mean_x = self.mean_module(X)
Expand Down Expand Up @@ -373,14 +376,13 @@ def __init__(
variational_distribution=variational_distribution,
variational_strategy=variational_strategy,
inducing_points=inducing_points,
input_transform=input_transform,
)

super().__init__(model=model, likelihood=likelihood, num_outputs=num_outputs)

if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
self.input_transform = input_transform

# for model fitting utilities
# TODO: make this a flag?
Expand Down
79 changes: 10 additions & 69 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,8 @@ class Model(Module, ABC):

Model cannot be used directly; it only defines an API for other BoTorch
models.

Args:
_has_transformed_inputs: A boolean denoting whether `train_inputs` are currently
stored as transformed or not.
_original_train_inputs: A Tensor storing the original train inputs for use in
`_revert_to_original_inputs`. Note that this is necessary since
transform / untransform cycle introduces numerical errors which lead
to upstream errors during training.
"""

_has_transformed_inputs: bool = False
_original_train_inputs: Optional[Tensor] = None

@abstractmethod
def posterior(
self,
Expand Down Expand Up @@ -199,57 +188,11 @@ def transform_inputs(
Returns:
A tensor of transformed inputs
"""
if input_transform is not None:
input_transform.to(X)
return input_transform(X)
try:
return self.input_transform(X)
except AttributeError:
return X

def _set_transformed_inputs(self) -> None:
r"""Update training inputs with transformed inputs."""
if hasattr(self, "input_transform") and not self._has_transformed_inputs:
if hasattr(self, "train_inputs"):
self._original_train_inputs = self.train_inputs[0]
with torch.no_grad():
X_tf = self.input_transform.preprocess_transform(
self.train_inputs[0]
)
self.set_train_data(X_tf, strict=False)
self._has_transformed_inputs = True
else:
warnings.warn(
"Could not update `train_inputs` with transformed inputs "
f"since {self.__class__.__name__} does not have a `train_inputs` "
"attribute. Make sure that the `input_transform` is applied to "
"both the train inputs and test inputs.",
RuntimeWarning,
)

def _revert_to_original_inputs(self) -> None:
r"""Revert training inputs back to original."""
if hasattr(self, "input_transform") and self._has_transformed_inputs:
self.set_train_data(self._original_train_inputs, strict=False)
self._has_transformed_inputs = False

def eval(self) -> Model:
r"""Puts the model in `eval` mode and sets the transformed inputs."""
self._set_transformed_inputs()
return super().eval()

def train(self, mode: bool = True) -> Model:
r"""Puts the model in `train` mode and reverts to the original inputs.

Args:
mode: A boolean denoting whether to put in `train` or `eval` mode.
If `False`, model is put in `eval` mode.
"""
if mode:
self._revert_to_original_inputs()
else:
self._set_transformed_inputs()
return super().train(mode=mode)
warnings.warn(
"`Model.transform_inputs` is deprecated. Input transforms are applied at GPyTorch model `__call__` instead.",
DeprecationWarning,
)
return X


class ModelList(Model):
Expand Down Expand Up @@ -413,10 +356,8 @@ def transform_inputs(self, X: Tensor) -> List[Tensor]:
Returns:
A list of tensors of transformed inputs.
"""
transformed_X_list = []
for model in self.models:
try:
transformed_X_list.append(model.input_transform(X))
except AttributeError:
transformed_X_list.append(X)
return transformed_X_list
warnings.warn(
"`Model.transform_inputs` is deprecated. Input transforms are applied at GPyTorch model `__call__` instead.",
DeprecationWarning,
)
return [X for _ in self.models]
10 changes: 0 additions & 10 deletions botorch/models/model_list_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,3 @@ def subset_output(self, idcs: List[int]) -> ModelListGP:
The current model, subset to the specified output indices.
"""
return self.__class__(*[deepcopy(self.models[i]) for i in idcs])

def _set_transformed_inputs(self) -> None:
r"""Update training inputs with transformed inputs."""
for m in self.models:
m._set_transformed_inputs()

def _revert_to_original_inputs(self) -> None:
r"""Revert training inputs back to original."""
for m in self.models:
m._revert_to_original_inputs()
53 changes: 6 additions & 47 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,20 @@ class InputTransform(ABC):
transform_on_train: bool
transform_on_fantasize: bool

def forward(self, X: Tensor) -> Tensor:
def forward(self, X: Tensor, is_training_input: bool) -> Tensor:
r"""Transform the inputs to a model.

Args:
X: A `batch_shape x n x d`-dim tensor of inputs.
is_training_input: A boolean denoting whether the input is a training input.
If true, only the transforms with `transform_on_train=True` are applied.
Otherwise, the transform will be applied based on `transform_on_eval`
and `transform_on_fantasize` options.

Returns:
A `batch_shape x n' x d`-dim tensor of transformed inputs.
"""
if self.training:
if is_training_input:
if self.transform_on_train:
return self.transform(X)
elif self.transform_on_eval:
Expand Down Expand Up @@ -123,33 +127,6 @@ def equals(self, other: InputTransform) -> bool:
)
)

def preprocess_transform(self, X: Tensor) -> Tensor:
r"""Apply transforms for preprocessing inputs.

The main use cases for this method are 1) to preprocess training data
before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
acquisition functions so that `X_baseline` is "preprocessed" with the
same transformations as the cached training inputs.

Args:
X: A `batch_shape x n x d`-dim tensor of inputs.

Returns:
A `batch_shape x n x d`-dim tensor of (transformed) inputs.
"""
if self.transform_on_train:
# We need to disable learning of bounds here.
# See why: https://github.com/pytorch/botorch/issues/1078.
if hasattr(self, "learn_bounds"):
learn_bounds = self.learn_bounds
self.learn_bounds = False
result = self.transform(X)
self.learn_bounds = learn_bounds
return result
else:
return self.transform(X)
return X


class ChainedInputTransform(InputTransform, ModuleDict):
r"""An input transform representing the chaining of individual transforms."""
Expand Down Expand Up @@ -224,24 +201,6 @@ def equals(self, other: InputTransform) -> bool:
t1 == t2 for t1, t2 in zip(self.values(), other.values())
)

def preprocess_transform(self, X: Tensor) -> Tensor:
r"""Apply transforms for preprocessing inputs.

The main use cases for this method are 1) to preprocess training data
before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
acquisition functions so that `X_baseline` is "preprocessed" with the
same transformations as the cached training inputs.

Args:
X: A `batch_shape x n x d`-dim tensor of inputs.

Returns:
A `batch_shape x n x d`-dim tensor of (transformed) inputs.
"""
for tf in self.values():
X = tf.preprocess_transform(X)
return X


class ReversibleInputTransform(InputTransform, ABC):
r"""An abstract class for a reversible input transform.
Expand Down