Skip to content

[Bug] Rational Quadratic Kernel with Deep GPs - crashes due to shape mismatch #2674

@mrlj-hash

Description

@mrlj-hash

🐛 Bug

When using the RQKernel with variational Deep GPs, a crash occurs because of a shape mismatch.

To reproduce

** Code snippet to reproduce **

###################### IMPORTS ##################################
from gpytorch.means import ConstantMean
from tqdm import tqdm
import torch
from torch.optim import Adam
import gpytorch as gpy
from torch.utils.data import TensorDataset, DataLoader
from gpytorch.models.deep_gps import DeepGP, DeepGPLayer
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood
# from gpytorch.likelihoods import MultitaskGaussianLikelihood (for comparison)
from gpytorch.optim import NGD
from gpytorch.variational import TrilNaturalVariationalDistribution
from gpytorch.variational import VariationalStrategy
from gpytorch.mlls import VariationalELBO as VELBO
from gpytorch.mlls import DeepApproximateMLL as DLL
from gpytorch.kernels import ScaleKernel, RQKernel
from gpytorch.kernels import RBFKernel, MaternKernel #(for comparison)
from scipy.io import loadmat
from math import floor

########################### SETUP ########################################
seed = 1
train_prop = 0.8
learning_rate = 0.1 # for NGD
batch_size = 64
hp_learn_rate = 0.01 # for Adam
num_ips = 20 # total number of inducing points (each GP in the network gets an equal number. If not integer, rounding)
num_tasks = 1 # number of tasks - assuming that multioutput deep GPs with 1 task is comparable to a single output DGP
num_samples = 50 # amount of sampling for deep GPs
num_hidden = 3 # number of GPs in the hidden layer
num_epochs = 20 # number of training epochs

torch.manual_seed(seed)

################ DATA SETUP ##############################################
data = torch.Tensor(loadmat('Data/elevators.mat')['data']) # the elevator dataset as used in the DGP notebook
X = data[:, :-1]
X = X - X.min(0)[0]
X = 2 * (X / X.max(0)[0]) - 1
y = data[:, -1]

train_n = int(floor(train_prop * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

# move to accelerator device if available, otherwise use the CPU
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
train_x, train_y, test_x, test_y = train_x.to(device), train_y.to(device), test_x.to(device), test_y.to(device)

# set up a loader for the training data
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

num_cases = train_x.shape[0]
num_dims = train_x.shape[1]

# likelihood = MultitaskGaussianLikelihood(num_tasks = num_tasks) # assuming if num_tasks = 1, this is comparable to Gaussian Likelihood
likelihood = GaussianLikelihood()

inducing_points = train_x[:num_ips,:] # selecting a subset of the data to initialise the training points
inducing_points = torch.as_tensor(inducing_points)
inducing_points = inducing_points.to(device)

################## MODEL CLASSES ##############################################

# class to use for Deep GP layers
class DGPHiddenLayer(DeepGPLayer):
    def __init__(self, input_dims, output_dims, num_inducing=20):
        if output_dims is not None:
            inducing_points = torch.randn(output_dims, num_inducing, input_dims)
            batch_shape = torch.Size([output_dims])
        else:
            inducing_points = torch.randn(num_inducing, input_dims)
            batch_shape = torch.Size([])

        variational_distribution = TrilNaturalVariationalDistribution(num_inducing, batch_shape=batch_shape)
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations = True)

        super(DGPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims)

        self.mean_module = ConstantMean(batch_shape=batch_shape)

        self.covar_module = ScaleKernel(RQKernel(batch_shape=batch_shape, ard_num_dims=input_dims),batch_shape=batch_shape,ard_num_dims=None)
        # included for comparison
        # self.covar_module = ScaleKernel(RQKernel(batch_shape=batch_shape),batch_shape=batch_shape,ard_num_dims=None)
        # self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape, ard_num_dims=input_dims),batch_shape=batch_shape,ard_num_dims=None)
        # self.covar_module = ScaleKernel(MaternKernel(batch_shape=batch_shape, ard_num_dims=input_dims, nu = 1.5),batch_shape=batch_shape,ard_num_dims=None)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

# class for Deep GPs
class VariationalDeepGPModel(DeepGP):
    def __init__(self, inducing_points, num_hidden, num_tasks):

        num_inducing_points = inducing_points.size(-2)
        inducing_points = inducing_points.contiguous()

        hidden_layer = DGPHiddenLayer(input_dims=inducing_points.shape[-1],
                                      output_dims=num_hidden,
                                      num_inducing = int(num_inducing_points / (num_hidden + 1)))

        last_layer = DGPHiddenLayer(input_dims=hidden_layer.output_dims, output_dims=num_tasks,
                                    num_inducing = int(num_inducing_points / (num_hidden + 1)))
        super().__init__()

        self.hidden_layer = hidden_layer
        self.last_layer = last_layer

    def forward(self, inputs):
        hidden_rep1 = self.hidden_layer(inputs)
        output = self.last_layer(hidden_rep1)
        return output

model = VariationalDeepGPModel(inducing_points, num_hidden, num_tasks = None)

# move model and likelihood to appropriate device
model = model.to(device)
likelihood = likelihood.to(device)

################## TRAINING #####################################################

model.train()
likelihood.train()

variational_ngd_optimizer = NGD(model.variational_parameters(), num_data=train_y.size(0), lr=learning_rate)
hyperparameter_optimizer = Adam([
    {'params': model.hyperparameters()},
    {'params': likelihood.parameters()},
], lr=hp_learn_rate)

mll = DLL(VELBO(likelihood, model, train_x.shape[-2]))

epochs_iter = tqdm(range(num_epochs), desc="Epoch")

for i in epochs_iter:
    minibatch_iter = tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        ### Perform NGD step to optimise variational parameters
        with gpy.settings.num_likelihood_samples(num_samples):
            variational_ngd_optimizer.zero_grad()
            output = model(x_batch)
            loss = -mll(output, y_batch)
            loss.backward()
            variational_ngd_optimizer.step()

            # now optimise kernel hyperparameters
            hyperparameter_optimizer.zero_grad()
            output = model(x_batch)
            loss = -mll(output, y_batch)
            loss.backward()
            minibatch_iter.set_postfix(loss=loss.item())
            hyperparameter_optimizer.step()

** Stack trace/error message **

Traceback (most recent call last):
  File "/home/blj/pycharm-community-2024.2.4/plugins/python-ce/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 152, in <module>
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/module.py", line 82, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<input>", line 122, in forward
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/models/deep_gps/deep_gp.py", line 78, in __call__
    inputs = torch.distributions.Normal(loc=inputs.mean, scale=inputs.variance.sqrt()).rsample()
                                                               ^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/distributions/multitask_multivariate_normal.py", line 278, in variance
    var = super().variance
          ^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py", line 366, in variance
    diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 1419, in diagonal
    return self._diagonal()
           ^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/linear_operator/operators/block_diag_linear_operator.py", line 94, in _diagonal
    res = self.base_linear_op._diagonal().contiguous()
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py", line 30, in _diagonal
    return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py", line 30, in <genexpr>
    return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)
               ^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py", line 30, in _diagonal
    return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py", line 30, in <genexpr>
    return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)
               ^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 118, in _diagonal
    res = super(Kernel, self.kernel).__call__(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/module.py", line 82, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/kernels/rq_kernel.py", line 74, in forward
    return postprocess_rq(
           ^^^^^^^^^^^^^^^
  File "/home/blj/anaconda3/envs/pythonProject/lib/python3.11/site-packages/gpytorch/kernels/rq_kernel.py", line 70, in postprocess_rq
    return (1 + dist_mat.div(2 * alpha)).pow(-alpha)
                ^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (50) must match the size of tensor b (3) at non-singleton dimension 0

System information

Please complete the following information:

  • GPyTorch Version: 1.14
  • PyTorch Version: 2.7.1+cu126
  • Computer OS: Ubuntu 24.04.3 LTS

Additional context

With the same setup as above, both the RBFKernel and MaternKernel result in successful training. The problem also happens with using the RQKernel for Deep Sigma Point Processes, and when not using Automatic Relevance Detection.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions