
[Bug] mean_cache not attached with correct args in get_fantasy_model #2669

@JackBuck

Description


🐛 Bug

When a model is fantasized, the mean and covariance caches are recreated and reattached to the new prediction strategy. However, the mean cache is attached with no arguments, whereas DefaultPredictionStrategy.mean_cache looks it up with settings.observation_nan_policy.value() as an argument. Because the cache key includes these arguments, the lookup misses the attached entry and fantasy models recompute the mean cache unnecessarily. This is especially wasteful when many fantasy models are created repeatedly, for example inside an optimization loop.
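To illustrate why the mismatch forces a recomputation, here is a minimal pure-Python sketch of an argument-aware memoize cache. ToyStrategy, cache_key and _compute_mean_cache are hypothetical stand-ins, not GPyTorch code; the only assumption carried over is that cache keys have the shape (name, positional args, pickled kwargs), as the keys printed by the reproduction below suggest.

```python
import pickle


def cache_key(name, *args, **kwargs):
    # Hypothetical: key shape (name, positional args, pickled kwargs),
    # mirroring the _memoize_cache keys printed in the reproduction.
    return (name, args, pickle.dumps(kwargs))


class ToyStrategy:
    """Hypothetical stand-in for a prediction strategy with a memoized mean cache."""

    def __init__(self):
        self._memoize_cache = {}
        self.n_computations = 0

    def _compute_mean_cache(self):
        # Stands in for the expensive computation; counts how often it runs.
        self.n_computations += 1
        return "recomputed"

    def mean_cache(self, nan_policy):
        # Looked up with the NaN policy as a positional argument,
        # like DefaultPredictionStrategy.mean_cache.
        key = cache_key("mean_cache", nan_policy)
        if key not in self._memoize_cache:
            self._memoize_cache[key] = self._compute_mean_cache()
        return self._memoize_cache[key]


strat = ToyStrategy()
# Attach a precomputed value with no args, as get_fantasy_model does.
strat._memoize_cache[cache_key("mean_cache")] = "precomputed"

value = strat.mean_cache("ignore")
print(value)                      # recomputed: the attached entry was missed
print(strat.n_computations)       # 1
print(len(strat._memoize_cache))  # 2: both keys now coexist
```

Attaching the value with the same positional argument, i.e. under cache_key("mean_cache", "ignore"), would make the lookup hit and leave n_computations at 0.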

To reproduce

MWE kindly adapted from #2631; however, beyond the setup, I don't know whether these issues are actually related.

Code snippet to reproduce

import torch
from gpytorch import settings as gpt_settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

torch.set_default_dtype(torch.double)

d = 10
mc_points = torch.rand(32, d, dtype=torch.double)


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


gp = SimpleGP(
    train_inputs=torch.rand(256, d, dtype=torch.double),
    train_targets=torch.rand(256, dtype=torch.double),
).eval()
gp(torch.rand(5, d, dtype=torch.double))  # set the caches before fantasize.

print(gp.prediction_strategy._memoize_cache.keys())

X = torch.rand(128, 5, d, dtype=torch.double, requires_grad=False)
Y = torch.rand(128, 5, dtype=torch.double, requires_grad=False)

fantasy_model = gp.get_fantasy_model(inputs=X, targets=Y).eval()

print(fantasy_model.prediction_strategy._memoize_cache.keys())

Stdout

dict_keys([('mean_cache', ('ignore',), b'\x80\x04}\x94.')])
dict_keys([('mean_cache', (), b'\x80\x04}\x94.'), ('covar_cache', (), b'\x80\x04}\x94.')])
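The third element of each key, b'\x80\x04}\x94.', is just the pickle of an empty keyword-argument dict, so the two mean_cache keys differ only in their positional-argument tuple: ('ignore',) versus (). A quick standard-library check of those bytes:

```python
import pickle

# Third element of every cache key printed above.
kwargs_pkl = b"\x80\x04}\x94."

# It unpickles to an empty dict: no keyword arguments were passed.
print(pickle.loads(kwargs_pkl))  # {}

# Pickling {} with protocol 4 reproduces exactly the same bytes.
print(pickle.dumps({}, protocol=4) == kwargs_pkl)  # True
```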

Expected Behavior

I would expect the mean_cache on the fantasy model to be attached with the same key as on self. In the example above, this means ('mean_cache', ('ignore',), b'\x80\x04}\x94.') instead of ('mean_cache', (), b'\x80\x04}\x94.').

I would perhaps also expect get_fantasy_model to take the observation_nan_policy setting into account.

System information


  • GPyTorch Version: 1.15.dev37+g8433c0b86
  • PyTorch Version: 2.8.0+cu128
  • Computer OS: Ubuntu 20.04.6 LTS (Focal Fossa)

Additional context

Line where the mean_cache is accessed using an argument:

return self._mean_cache(settings.observation_nan_policy.value())

Line where the mean_cache is added during get_fantasy_model:

add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
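The cleaner fix is presumably to pass settings.observation_nan_policy.value() as an extra argument in the add_to_cache call quoted above, assuming add_to_cache includes extra positional arguments in the cache key (the ('ignore',) key in the stdout suggests it does). Until then, one possible stop-gap is to re-key the attached entry after fantasizing. The helper below, rekey_mean_cache, is hypothetical and touches the private _memoize_cache dict, so treat it as a sketch rather than a supported API. It also assumes, as this issue does, that the attached value is valid under the current NaN policy. It is demonstrated on a plain dict shaped like the printed cache.

```python
def rekey_mean_cache(memo, nan_policy):
    """Hypothetical workaround: move an argument-less mean_cache entry to the
    key that is looked up with the NaN policy as a positional argument.

    Assumes keys of the form (name, args, pickled_kwargs), as printed in the
    reproduction. Mutates ``memo`` in place; other entries are left alone.
    """
    for key in list(memo):  # copy the keys so the dict can be mutated while looping
        if isinstance(key, tuple) and len(key) == 3 and key[0] == "mean_cache" and key[1] == ():
            memo[(key[0], (nan_policy,), key[2])] = memo.pop(key)


# Plain dict shaped like the fantasy model's cache in the reproduction.
empty_kwargs = b"\x80\x04}\x94."
memo = {
    ("mean_cache", (), empty_kwargs): "fantasy mean cache",
    ("covar_cache", (), empty_kwargs): "fantasy covar cache",
}
rekey_mean_cache(memo, "ignore")
print(sorted(memo))
```

With the reproduction above, it would be called as rekey_mean_cache(fantasy_model.prediction_strategy._memoize_cache, gpt_settings.observation_nan_policy.value()) before the fantasy model's first prediction.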
