
[Bug] Batch Independent Multioutput GP always uses DefaultPredictionStrategy #2659

@lahramon

Description

🐛 Bug

When creating an ExactGP model for multiple independent output dimensions according to this documentation, the training covariance seen by the prediction-strategy dispatch is a BlockInterleavedLinearOperator (wrapped by MultitaskMultivariateNormal.from_batch_mvn), not a LazyEvaluatedKernelTensor instance.
Thus, at this point in the code,

if isinstance(train_train_covar, LazyEvaluatedKernelTensor):
    cls = train_train_covar.kernel.prediction_strategy
else:
    cls = DefaultPredictionStrategy
return cls(train_inputs, train_prior_dist, train_labels, likelihood)

the resulting GP always falls back to DefaultPredictionStrategy, irrespective of the prediction_strategy defined by the kernel.
For inducing-point kernels, for example, this means that their sparse structure is not exploited and their computational benefits are lost.
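
To see the mismatch concretely, the snippet below (an illustration only; model stands for the BatchIndependentInducingPointGpModel from the reproduction script in the next section, kept in train mode) prints the type of the training covariance:

# Illustration only: "model" refers to the batch-independent model constructed
# in the reproduction script below, still in train mode.
prior = model(*model.train_inputs)
print(type(prior.lazy_covariance_matrix))
# -> BlockInterleavedLinearOperator (wrapped by MultitaskMultivariateNormal.from_batch_mvn),
#    so the isinstance(..., LazyEvaluatedKernelTensor) check above fails and
#    DefaultPredictionStrategy is chosen regardless of the kernel.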

To reproduce

import torch
import gpytorch

nx = 2
ny = 3

train_x = torch.rand(100, nx)  # 100 samples, 2 features
train_y = torch.rand(100, ny)  # 100 samples, 3 target variables
test_x = torch.rand(20, nx)  # 20 samples, 2 features

inducing_points = torch.rand(ny, 10, nx)  # 10 inducing points per output dimension, 2 features

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
    num_tasks=ny,
)


class BatchIndependentInducingPointGpModel(gpytorch.models.ExactGP):
    def __init__(
        self,
        train_x,
        train_y,
        likelihood,
        inducing_points,
        use_inducing_kernel: bool = True,
    ):
        super().__init__(train_x, train_y, likelihood)

        ny = train_y.shape[1]

        self.mean_module = gpytorch.means.ZeroMean(batch_shape=torch.Size([ny]))
        base_covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                batch_shape=torch.Size([ny]),
            ),
            batch_shape=torch.Size([ny]),
        )

        if use_inducing_kernel:
            self.covar_module = gpytorch.kernels.InducingPointKernel(
                base_covar_module,
                inducing_points=inducing_points,
                likelihood=likelihood,
            )
        else:
            self.covar_module = base_covar_module

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)

        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        )


gp_model_inducing = BatchIndependentInducingPointGpModel(
    train_x=train_x,
    train_y=train_y,
    likelihood=likelihood,
    inducing_points=inducing_points,
)
gp_model_inducing.eval()
gp_model_inducing.likelihood.eval()

prediction = gp_model_inducing(test_x)

assert isinstance(
    gp_model_inducing.prediction_strategy,
    gpytorch.models.exact_prediction_strategies.SGPRPredictionStrategy,
)

Expected Behavior

gp_model_inducing.prediction_strategy should be set to gpytorch.models.exact_prediction_strategies.SGPRPredictionStrategy, the optimized strategy that the inducing-point kernel defines.
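
One possible direction (a rough, untested sketch of the dispatch snippet quoted above, not a vetted patch; it assumes BlockInterleavedLinearOperator exposes the wrapped batch covariance as base_linear_op) would be to unwrap block-structured covariances before dispatching on the kernel:

from linear_operator.operators import BlockInterleavedLinearOperator

# Sketch only (untested): unwrap the interleaved block structure that
# MultitaskMultivariateNormal.from_batch_mvn wraps around the batch covariance.
covar = train_train_covar
if isinstance(covar, BlockInterleavedLinearOperator):
    covar = covar.base_linear_op  # assumed attribute holding the batched LazyEvaluatedKernelTensor
if isinstance(covar, LazyEvaluatedKernelTensor):
    cls = covar.kernel.prediction_strategy
else:
    cls = DefaultPredictionStrategy
return cls(train_inputs, train_prior_dist, train_labels, likelihood)

Whether the selected strategy (e.g. SGPRPredictionStrategy) can then handle the interleaved training covariance downstream is a separate question.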

System information

Please complete the following information:

  • GPyTorch Version 1.14
  • PyTorch Version 2.7.1
  • Pop!_OS 24.04
