Description
🐛 Bug
When a model is fantasized, the mean and covariance caches are recreated and reattached. However, the mean cache is attached with no extra arguments, whereas DefaultPredictionStrategy.mean_cache calls it with settings.observation_nan_policy.value() as an argument. As a result, the cache key never matches, and fantasy models recompute the mean cache more often than necessary, which is inefficient, particularly when many fantasy models are repeatedly created in an optimization loop (for example).
To reproduce
MWE kindly adapted from #2631; however, beyond the setup, I don't know whether these issues are actually related.
Code snippet to reproduce
import torch
from gpytorch import settings as gpt_settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
torch.set_default_dtype(torch.double)
d = 10
mc_points = torch.rand(32, d, dtype=torch.double)
class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

gp = SimpleGP(
    train_inputs=torch.rand(256, d, dtype=torch.double),
    train_targets=torch.rand(256, dtype=torch.double),
).eval()
gp(torch.rand(5, d, dtype=torch.double)) # set the caches before fantasize.
print(gp.prediction_strategy._memoize_cache.keys())
X = torch.rand(128, 5, d, dtype=torch.double, requires_grad=False)
Y = torch.rand(128, 5, dtype=torch.double, requires_grad=False)
fantasy_model = gp.get_fantasy_model(inputs=X, targets=Y).eval()
print(fantasy_model.prediction_strategy._memoize_cache.keys())
Stdout
dict_keys([('mean_cache', ('ignore',), b'\x80\x04}\x94.')])
dict_keys([('mean_cache', (), b'\x80\x04}\x94.'), ('covar_cache', (), b'\x80\x04}\x94.')])
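To see the practical consequence, one can continue the snippet above (an untested sketch; it assumes the mean_cache property performs its lookup with the nan-policy argument, as the source line quoted under Additional context below suggests):

# Untested sketch: evaluating the fantasy model looks up mean_cache with
# settings.observation_nan_policy.value() (i.e. the key ('ignore',)).
# Because the cache was stored under the empty-args key, the lookup misses
# and the mean cache is recomputed.
fantasy_model(torch.rand(5, d, dtype=torch.double))
print(fantasy_model.prediction_strategy._memoize_cache.keys())
# presumably: ('mean_cache', (), ...), ('covar_cache', (), ...) plus a new
# ('mean_cache', ('ignore',), ...) entry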
Expected Behavior
I would expect the mean_cache on the fantasy model to be attached under the same key as it is on self. In the case above, this means ('mean_cache', ('ignore',), b'\x80\x04}\x94.') instead of ('mean_cache', (), b'\x80\x04}\x94.').
I would perhaps also expect get_fantasy_model to be aware of the observation_nan_policy setting.
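For illustration only, a possible direction for a fix (an untested sketch; it assumes add_to_cache from gpytorch.utils.memoize forwards extra positional arguments into the cache key, which is what the key tuples above suggest) would be to attach the fantasy mean cache under the same key that the mean_cache property later uses for its lookup:

from gpytorch import settings
from gpytorch.utils.memoize import add_to_cache

# Hypothetical one-line change where the fantasy caches are attached:
# store the mean cache under the same nan-policy key used for the lookup.
add_to_cache(fant_strat, "mean_cache", fant_mean_cache, settings.observation_nan_policy.value())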
System information
Please complete the following information:
- GPyTorch Version: 1.15.dev37+g8433c0b86
- PyTorch Version: 2.8.0+cu128
- Computer OS: Ubuntu 20.04.6 LTS (Focal Fossa)
Additional context
Line where the mean_cache is accessed using an argument:
return self._mean_cache(settings.observation_nan_policy.value())
Line where the mean_cache is added during get_fantasy_model:
add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
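To make the key mismatch concrete, here is a simplified, self-contained illustration of args-keyed memoization (this is not GPyTorch's actual implementation, just an assumption about how the (name, args, pickled-kwargs) keys shown above behave):

import pickle

# Simplified illustration: the memoization key includes the call arguments,
# so a value stored under () is invisible to a lookup with ('ignore',).
def cache_key(name, args=(), kwargs=None):
    return name, args, pickle.dumps(kwargs or {})

cache = {}
cache[cache_key("mean_cache")] = "stored by get_fantasy_model"  # key: ('mean_cache', (), ...)
lookup = cache_key("mean_cache", ("ignore",))                   # key used by the mean_cache property
print(lookup in cache)  # False -> the mean cache would be recomputed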