🐛 Bug
When creating an ExactGP model for multiple independent output dimensions according to this documentation, the created BlockInterleavedLinearOperator is not a LazyEvaluatedKernelTensor instance.
Thus, at this point in the code,
if isinstance(train_train_covar, LazyEvaluatedKernelTensor):
    cls = train_train_covar.kernel.prediction_strategy
else:
    cls = DefaultPredictionStrategy
return cls(train_inputs, train_prior_dist, train_labels, likelihood)
the resulting GP always falls back to the DefaultPredictionStrategy, irrespective of the kernel.prediction_strategy defined by the kernel.
For inducing-point kernels, for example, this means their sparse structure is not exploited and their computational benefits are lost.
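To make the fallback concrete, here is a minimal, library-free sketch of the branch quoted above. All classes are hypothetical stand-ins that only borrow the GPyTorch names. It assumes, as described above, that the multitask covariance is a wrapper around the per-task kernel covariance (exposed here as a base_linear_op attribute). As a simplification, the stub kernel stores a strategy class directly, whereas in the quoted code kernel.prediction_strategy is simply some callable invoked as cls(...).

```python
# Library-free mimic of the prediction-strategy dispatch.
# Every class here is a hypothetical stand-in, not the real GPyTorch type.


class DefaultPredictionStrategy:
    pass


class SGPRPredictionStrategy:
    pass


class InducingPointKernel:
    # Simplification: the kernel advertises its preferred (sparse) strategy class.
    prediction_strategy = SGPRPredictionStrategy


class LazyEvaluatedKernelTensor:
    # Stand-in: remembers which kernel produced the covariance.
    def __init__(self, kernel):
        self.kernel = kernel


class BlockInterleavedLinearOperator:
    # Stand-in: wraps the batch of per-task covariances, hiding their type.
    def __init__(self, base_linear_op):
        self.base_linear_op = base_linear_op


def select_strategy(train_train_covar):
    # Same branching as the quoted snippet, returning the class
    # instead of instantiating it.
    if isinstance(train_train_covar, LazyEvaluatedKernelTensor):
        return train_train_covar.kernel.prediction_strategy
    return DefaultPredictionStrategy


kernel_covar = LazyEvaluatedKernelTensor(InducingPointKernel())
multitask_covar = BlockInterleavedLinearOperator(kernel_covar)

print(select_strategy(kernel_covar).__name__)     # SGPRPredictionStrategy
print(select_strategy(multitask_covar).__name__)  # DefaultPredictionStrategy
```

The wrapper fails the isinstance check, so the kernel it contains is never consulted.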
To reproduce
import torch
import gpytorch

nx = 2
ny = 3
train_x = torch.rand(100, nx)  # 100 samples, 2 features
train_y = torch.rand(100, ny)  # 100 samples, 3 target variables
test_x = torch.rand(20, nx)  # 20 samples, 2 features
inducing_points = torch.rand(ny, 10, nx)  # 10 inducing points per output, 2 features
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
    num_tasks=ny,
)

class BatchIndependentInducingPointGpModel(gpytorch.models.ExactGP):
    def __init__(
        self,
        train_x,
        train_y,
        likelihood,
        inducing_points,
        use_inducing_kernel: bool = True,
    ):
        super().__init__(train_x, train_y, likelihood)
        ny = train_y.shape[1]
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=torch.Size([ny]))
        base_covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                batch_shape=torch.Size([ny]),
            ),
            batch_shape=torch.Size([ny]),
        )
        if use_inducing_kernel:
            self.covar_module = gpytorch.kernels.InducingPointKernel(
                base_covar_module,
                inducing_points=inducing_points,
                likelihood=likelihood,
            )
        else:
            self.covar_module = base_covar_module

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        )

gp_model_inducing = BatchIndependentInducingPointGpModel(
    train_x=train_x,
    train_y=train_y,
    likelihood=likelihood,
    inducing_points=inducing_points,
)
gp_model_inducing.eval()
gp_model_inducing.likelihood.eval()
prediction = gp_model_inducing(test_x)
assert isinstance(
    gp_model_inducing.prediction_strategy,
    gpytorch.models.exact_prediction_strategies.SGPRPredictionStrategy,
)
Expected Behavior
The gp_model_inducing.prediction_strategy should be set to the optimized gpytorch.models.exact_prediction_strategies.SGPRPredictionStrategy defined for the inducing-point kernel.
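One way the expected selection could be reached is sketched below, again with hypothetical stand-in classes rather than the real GPyTorch types: look through the interleaving wrapper (assumed here to expose its inner covariance as base_linear_op) before the isinstance check. This only illustrates the selection step. Whether SGPRPredictionStrategy can then handle the interleaved multitask training data is a separate question that the sketch does not address.

```python
# Hypothetical dispatch variant that unwraps interleaving wrappers first.
# Every class here is a stand-in, not the real GPyTorch type.


class DefaultPredictionStrategy:
    pass


class SGPRPredictionStrategy:
    pass


class InducingPointKernel:
    # Simplification: the kernel advertises its preferred strategy class.
    prediction_strategy = SGPRPredictionStrategy


class LazyEvaluatedKernelTensor:
    def __init__(self, kernel):
        self.kernel = kernel


class BlockInterleavedLinearOperator:
    def __init__(self, base_linear_op):
        self.base_linear_op = base_linear_op


def select_strategy_unwrapped(train_train_covar):
    # Peel off (possibly nested) interleaving wrappers to reach the
    # per-task covariance produced by the kernel.
    covar = train_train_covar
    while isinstance(covar, BlockInterleavedLinearOperator):
        covar = covar.base_linear_op
    if isinstance(covar, LazyEvaluatedKernelTensor):
        return covar.kernel.prediction_strategy
    return DefaultPredictionStrategy


multitask_covar = BlockInterleavedLinearOperator(
    LazyEvaluatedKernelTensor(InducingPointKernel())
)
print(select_strategy_unwrapped(multitask_covar).__name__)  # SGPRPredictionStrategy
```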
System information
- GPyTorch version 1.14
- PyTorch version 2.7.1
- OS: Pop!_OS 24.04