🐛 Bug
Getting a fantasy model for a simple multi-task GP throws an error
To reproduce
Here is a minimal working example of the bug:
import torch
import gpytorch


class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_tasks):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=n_tasks
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=n_tasks, rank=1
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


input_dim = 1
output_dim = 2
n_train = 10
train_x = torch.randn(n_train, input_dim)
train_y = torch.randn(n_train, output_dim)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=output_dim)
model = MultitaskGPModel(train_x, train_y, likelihood, output_dim)
model.train()
model.eval()

# get a posterior to fill in caches
model(torch.randn(n_train, input_dim))

# Generate some new data and get fantasy model
n_new = 5
new_x = torch.randn(n_new, input_dim)
new_y = torch.randn(n_new, output_dim)
model.get_fantasy_model(new_x, new_y)
Stack trace/error message
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/acorso/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 239, in get_fantasy_model
    new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/acorso/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 196, in get_fantasy_strategy
    small_system_rhs = targets - fant_mean - ftcm
                       ~~~~~~~~~~~~~~~~~~~~^~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (10) at non-singleton dimension 1
Expected Behavior
The fantasy model with the appropriately updated cache should be returned
System information
- gpytorch 1.12
- torch 2.4.0+cu121
- macOS
Additional context
This has already been discussed in #800 and #805, and PR #2317 was merged that supposedly implemented this feature. However, the test added there passes only because it adds a single data point to produce the fantasy model. If you set n_new = 1 in the example above, it also runs without error, but I'm skeptical that the right thing is happening if it doesn't work for more than one additional point.
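A possible explanation for why only the single-point case runs, sketched with plain torch. The sizes in the error (2 and 10) match `n_tasks` and `n_new * n_tasks`, so the snippet subtracts a stand-in tensor flattened across points and tasks from `(n_new, n_tasks)` targets. That `ftcm` really has this flattened shape is an assumption I have not checked against the gpytorch source; `try_subtract` and `flat_term` are illustrative names only.

```python
import torch


def try_subtract(n_new, n_tasks):
    """Return the result shape, or the error message if broadcasting fails."""
    # Targets as passed to get_fantasy_model: one row per new point.
    targets = torch.zeros(n_new, n_tasks)
    # ASSUMPTION: stand-in for an internal term flattened over points and tasks.
    flat_term = torch.zeros(n_new * n_tasks)
    try:
        return tuple((targets - flat_term).shape)
    except RuntimeError as err:
        return str(err)


print(try_subtract(5, 2))  # broadcasting fails: trailing sizes 2 vs 10
print(try_subtract(1, 2))  # broadcasting succeeds: trailing sizes 2 vs 2
```

If this is what happens, `n_new = 1` succeeds only because the flattened length happens to equal `n_tasks`, which would say nothing about whether the resulting cache is correct.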