-
Notifications
You must be signed in to change notification settings - Fork 411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Move input transforms to GPyTorch #1372
base: main
Are you sure you want to change the base?
Conversation
Summary: This diff presents a minimal implementation of input transforms in GPyTorch. What this does: * Moves the `transform_inputs` from BoTorch `Model` to GPyTorch `GP` class, with some modifications to explicitly identify whether given inputs are train or test inputs. * Modifies the `InputTransform.forward` call to use `is_training_input` argument instead of `self.training` check to apply the transforms that have `transform_on_train=True`. * Removes `preprocess_transform` method since this is no-longer needed. * For `ExactGP` models, it transforms both train and test inputs in `__call__`. For `train_inputs` it always uses `is_training_input=True`. For generic `inputs`, it uses `is_training_input=self.training` which signals that these are training inputs when the model is in `train` mode, and that these are test inputs when the model is in `eval` mode. * For `ApproximateGP` models, it applies the transform to `inputs` in `__call__` using `is_training_input=self.training`. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms `inducing_points`, thus fixes the previous bug with `inducing_points` getting transformed in `train` but not getting transformed in `eval`. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube). * For BoTorch `SingleTaskVariationalGP`, it moves the `input_transform` attribute down to `_SingleTaskVariationalGP`, which is the actual `ApproximateGP` instance. This makes the transform accessible from GPyTorch. What this doesn't do: * It doesn't do anything about `DeterministicModel`s. Those will still need to deal with their own transforms, which is not implemented here. If we make `Model` inherit from `GP`, we can keep the existing setup with very minimal changes. * It does not clean up the call sites for `self.transform_inputs`. This is just made into a no-op and the clean-up is left for later. * It does not upstream the abstract `InputTransform` classes to GPyTorch. That'll be done if we decide to go forward with this design. * It does not touch `PairwiseGP`. `PairwiseGP` has some non-standard use of input transforms, so it needs an audit to make sure things still work fine. * I didn't look into `ApproximateGP.fantasize`. This may need some changes similar to `ExactGP.get_fantasy_model`. * It does not support `PyroGP` and `DeepGP`. Differential Revision: D39147547 fbshipit-source-id: ed2745b0ff666a13764759e1511a139c228d1d39
This pull request was exported from Phabricator. Differential Revision: D39147547 |
I like this design!
I think we probably want to have sth like a gpytorch
@gpleiss do you have any high-level feedback on the transform setup (https://github.com/pytorch/botorch/tree/main/botorch/models/transforms) that we'd want to incorporate when upstreaming those? One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor) then repeatedly applying it to the same inputs during training (for the full batch case of exact GPs anyway) could be quite wasteful. Is there an elegant solution to this by means of caching the transformed values of the training data and evicting that cache when they are reset? |
@Balandat I really like the botorch API, and this would be super useful to have upstream in GPyTorch!
There probably is an elegant way to do this, but nothing really comes to mind. We should circle back to this at some point, but at the very least a power user could (e.g.) apply a pre-trained NN to the inputs without using the transforms API. |
Summary:
This diff presents a minimal implementation of input transforms in GPyTorch. See cornellius-gp/gpytorch#2114 for GPyTorch side of these changes.
What this does:
transform_inputs
from BoTorchModel
to GPyTorchGP
class, with some modifications to explicitly identify whether given inputs are train or test inputs.InputTransform.forward
call to useis_training_input
argument instead ofself.training
check to apply the transforms that havetransform_on_train=True
.preprocess_transform
method since this is no-longer needed.ExactGP
models, it transforms both train and test inputs in__call__
. Fortrain_inputs
it always usesis_training_input=True
. For genericinputs
, it usesis_training_input=self.training
which signals that these are training inputs when the model is intrain
mode, and that these are test inputs when the model is ineval
mode.ApproximateGP
models, it applies the transform toinputs
in__call__
usingis_training_input=self.training
. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transformsinducing_points
, thus fixes the previous bug withinducing_points
getting transformed intrain
but not getting transformed ineval
. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).SingleTaskVariationalGP
, it moves theinput_transform
attribute down to_SingleTaskVariationalGP
, which is the actualApproximateGP
instance. This makes the transform accessible from GPyTorch.What this doesn't do:
DeterministicModel
s. Those will still need to deal with their own transforms, which is not implemented here. If we makeModel
inherit fromGP
, we can keep the existing setup with very minimal changes.self.transform_inputs
. This is just made into a no-op and the clean-up is left for later.InputTransform
classes to GPyTorch. That'll be done if we decide to go forward with this design.PairwiseGP
.PairwiseGP
has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.ApproximateGP.fantasize
. This may need some changes similar toExactGP.get_fantasy_model
.PyroGP
andDeepGP
.Differential Revision: D39147547