[RFC] Move input transforms to GPyTorch #2114
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This diff presents a minimal implementation of input transforms in GPyTorch, as requested in #1652. This should be viewed together with meta-pytorch/botorch#1372. The input transforms themselves are currently implemented in https://github.com/pytorch/botorch/blob/cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067/botorch/models/transforms/input.py
What this does:
transform_inputsfrom BoTorchModelto GPyTorchGPclass, with some modifications to explicitly identify whether given inputs are train or test inputs.InputTransform.forwardcall to useis_training_inputargument instead ofself.trainingcheck to apply the transforms that havetransform_on_train=True.preprocess_transformmethod since this is no-longer needed.ExactGPmodels, it transforms both train and test inputs in__call__. Fortrain_inputsit always usesis_training_input=True. For genericinputs, it usesis_training_input=self.trainingwhich signals that these are training inputs when the model is intrainmode, and that these are test inputs when the model is inevalmode.ApproximateGPmodels, it applies the transform toinputsin__call__usingis_training_input=self.training. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transformsinducing_points, thus fixes the previous bug withinducing_pointsgetting transformed intrainbut not getting transformed ineval. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).SingleTaskVariationalGP, it moves theinput_transformattribute down to_SingleTaskVariationalGP, which is the actualApproximateGPinstance. This makes the transform accessible from GPyTorch.What this doesn't do:
DeterministicModels. Those will still need to deal with their own transforms, which is not implemented here. If we makeModelinherit fromGP, we can keep the existing setup with very minimal changes.self.transform_inputs. This is just made into a no-op and the clean-up is left for later.InputTransformclasses to GPyTorch. That'll be done if we decide to go forward with this design.PairwiseGP.PairwiseGPhas some non-standard use of input transforms, so it needs an audit to make sure things still work fine.ApproximateGP.fantasize. This may need some changes similar toExactGP.get_fantasy_model.PyroGPandDeepGP.