Skip to content

Commit 3fac12d

Browse files
committed
Add batch_shape property to GP model class
Implements #2301. TODO: - support for approximate GP models - unit tests - verify compatibility with the botorch setup of other models
1 parent a2b5fd8 commit 3fac12d

File tree

3 files changed

+49
-10
lines changed

3 files changed

+49
-10
lines changed

gpytorch/models/exact_gp.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@ def __init__(self, train_inputs, train_targets, likelihood):
7171

7272
self.prediction_strategy = None
7373

74+
@property
def batch_shape(self) -> torch.Size:
    r"""The batch shape of the model.

    This is the batch shape as seen from the model's input/output interface,
    regardless of how the model is represented internally. For a model with
    `(m)` outputs, feeding a `test_batch_shape x q x d`-shaped input to the
    model in eval mode yields a distribution of shape
    `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
    """
    # The leading dimensions of the first training input (everything before
    # the trailing `n x d` data dimensions) define the model's batch shape.
    first_train_input = self.train_inputs[0]
    return first_train_input.shape[:-2]
84+
7485
@property
def train_targets(self):
    """The training targets of the model.

    Returns the internally stored ``_train_targets`` tensor; presumably kept
    in sync by a corresponding setter elsewhere in the class — not visible here.
    """
    return self._train_targets
@@ -160,9 +171,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
160171
"all test independent caches exist. Call the model on some data first!"
161172
)
162173

163-
model_batch_shape = self.train_inputs[0].shape[:-2]
164-
165-
if self.train_targets.dim() > len(model_batch_shape) + 1:
174+
if self.train_targets.dim() > len(self.batch_shape) + 1:
166175
raise RuntimeError("Cannot yet add fantasy observations to multitask GPs, but this is coming soon!")
167176

168177
if not isinstance(inputs, list):
@@ -182,17 +191,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
182191

183192
# Check whether we can properly broadcast batch dimensions
184193
try:
185-
torch.broadcast_shapes(model_batch_shape, target_batch_shape)
194+
torch.broadcast_shapes(self.batch_shape, target_batch_shape)
186195
except RuntimeError:
187196
raise RuntimeError(
188-
f"Model batch shape ({model_batch_shape}) and target batch shape "
197+
f"Model batch shape ({self.batch_shape}) and target batch shape "
189198
f"({target_batch_shape}) are not broadcastable."
190199
)
191200

192-
if len(model_batch_shape) > len(input_batch_shape):
193-
input_batch_shape = model_batch_shape
194-
if len(model_batch_shape) > len(target_batch_shape):
195-
target_batch_shape = model_batch_shape
201+
if len(self.batch_shape) > len(input_batch_shape):
202+
input_batch_shape = self.batch_shape
203+
if len(self.batch_shape) > len(target_batch_shape):
204+
target_batch_shape = self.batch_shape
196205

197206
# If input has no fantasy batch dimension but target does, we can save memory and computation by not
198207
# computing the covariance for each element of the batch. Therefore we don't expand the inputs to the

gpytorch/models/gp.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
#!/usr/bin/env python3
22

3+
import torch
4+
35
from ..module import Module
46

57

68
class GP(Module):
    @property
    def batch_shape(self) -> torch.Size:
        r"""The batch shape of the model.

        This is the batch shape as seen from the model's input/output
        interface, regardless of how the model is represented internally.
        For a model with `(m)` outputs, feeding a `test_batch_shape x q x d`-
        shaped input to the model in eval mode yields a distribution of shape
        `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.

        Raises:
            NotImplementedError: always, unless a subclass overrides this.
        """
        cls_name = type(self).__name__
        raise NotImplementedError(f"{cls_name} does not define batch_shape property")

gpytorch/models/model_list.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,24 @@ def __init__(self, *models):
4040
def num_outputs(self):
4141
return len(self.models)
4242

43+
@property
def batch_shape(self) -> torch.Size:
    r"""The batch shape of the model.

    This is the batch shape as seen from the model's input/output interface,
    regardless of how the model is represented internally. For a model with
    `(m)` outputs, feeding a `test_batch_shape x q x d`-shaped input to the
    model in eval mode yields a distribution of shape
    `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.

    Raises:
        NotImplementedError: if the constituent models' batch shapes differ.
    """
    reference_shape = self.models[0].batch_shape
    # Guard: every submodel must agree with the first one's batch shape.
    if any(m.batch_shape != reference_shape for m in self.models[1:]):
        # TODO: Allow broadcasting of model batch shapes
        raise NotImplementedError(
            f"`{self.__class__.__name__}.batch_shape` is only supported if all "
            "constituent models have the same `batch_shape`."
        )
    return reference_shape
60+
4361
def forward_i(self, i, *args, **kwargs):
4462
return self.models[i].forward(*args, **kwargs)
4563

0 commit comments

Comments
 (0)