
Commit 14bc8b0

Add batch_shape property to GP model class

Implements #2301. TODO: Verify compatibility with the botorch setup of other models.

1 parent f73fa7d

9 files changed (+97, −19 lines)

gpytorch/models/approximate_gp.py

Lines changed: 18 additions & 0 deletions
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3

+import torch
+
 from .gp import GP
 from .pyro import _PyroMixin  # This will only contain functions if Pyro is installed

@@ -40,6 +42,11 @@ class ApproximateGP(GP, _PyroMixin):
     >>> # test_x = ...;
     >>> model(test_x)  # Returns the approximate GP latent function at test_x
     >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+
+    :ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
+        independent of the internal representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
     """

     def __init__(self, variational_strategy):
@@ -49,6 +56,17 @@ def __init__(self, variational_strategy):
     def forward(self, x):
         raise NotImplementedError

+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self.variational_strategy.batch_shape
+
     def pyro_guide(self, input, beta=1.0, name_prefix=""):
         r"""
         (For Pyro integration only). The component of a `pyro.guide` that
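For context, a minimal sketch of how the new property surfaces on an ApproximateGP subclass: it delegates to variational_strategy.batch_shape, which (per the _variational_strategy.py hunk below) is inducing_points.shape[:-2]. The ToyApproximateGP class and all shapes here are illustrative, not part of the commit.

# Minimal sketch: ApproximateGP.batch_shape delegates to the variational strategy.
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class ToyApproximateGP(ApproximateGP):
    def __init__(self, inducing_points):
        batch_shape = inducing_points.shape[:-2]
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=batch_shape
        )
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution)
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = ToyApproximateGP(torch.randn(2, 3, 16, 1))  # batch shape (2, 3), 16 inducing points, d=1
print(model.batch_shape)  # torch.Size([2, 3]), read off variational_strategy.batch_shape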

gpytorch/models/exact_gp.py

Lines changed: 22 additions & 8 deletions
@@ -50,6 +50,11 @@ class ExactGP(GP):
     >>> # test_x = ...;
     >>> model(test_x)  # Returns the GP latent function at test_x
     >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+
+    :ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
+        independent of the internal representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
     """

     def __init__(self, train_inputs, train_targets, likelihood):
@@ -71,6 +76,17 @@ def __init__(self, train_inputs, train_targets, likelihood):

         self.prediction_strategy = None

+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self.train_inputs[0].shape[:-2]
+
     @property
     def train_targets(self):
         return self._train_targets
@@ -160,8 +176,6 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
                 "all test independent caches exist. Call the model on some data first!"
             )

-        model_batch_shape = self.train_inputs[0].shape[:-2]
-
         if not isinstance(inputs, list):
             inputs = [inputs]

@@ -184,17 +198,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):

         # Check whether we can properly broadcast batch dimensions
         try:
-            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
+            torch.broadcast_shapes(self.batch_shape, target_batch_shape)
         except RuntimeError:
             raise RuntimeError(
-                f"Model batch shape ({model_batch_shape}) and target batch shape "
+                f"Model batch shape ({self.batch_shape}) and target batch shape "
                 f"({target_batch_shape}) are not broadcastable."
             )

-        if len(model_batch_shape) > len(input_batch_shape):
-            input_batch_shape = model_batch_shape
-        if len(model_batch_shape) > len(target_batch_shape):
-            target_batch_shape = model_batch_shape
+        if len(self.batch_shape) > len(input_batch_shape):
+            input_batch_shape = self.batch_shape
+        if len(self.batch_shape) > len(target_batch_shape):
+            target_batch_shape = self.batch_shape

         # If input has no fantasy batch dimension but target does, we can save memory and computation by not
         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
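The get_fantasy_model changes are a pure refactor: the locally computed model_batch_shape was exactly train_inputs[0].shape[:-2], which is what the new property returns. A standalone sketch of the broadcast check, with illustrative shape values:

# Standalone sketch of the broadcast check get_fantasy_model runs against self.batch_shape.
import torch

model_batch_shape = torch.Size([2, 3])      # what ExactGP.batch_shape returns: train_inputs[0].shape[:-2]
target_batch_shape = torch.Size([5, 1, 3])  # batch shape of the fantasy targets

try:
    # Shapes are right-aligned and broadcast dimension-wise.
    print(torch.broadcast_shapes(model_batch_shape, target_batch_shape))  # torch.Size([5, 2, 3])
except RuntimeError:
    raise RuntimeError(
        f"Model batch shape ({model_batch_shape}) and target batch shape "
        f"({target_batch_shape}) are not broadcastable."
    )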

gpytorch/models/gp.py

Lines changed: 13 additions & 1 deletion
@@ -1,7 +1,19 @@
 #!/usr/bin/env python3

+import torch
+
 from ..module import Module


 class GP(Module):
-    pass
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        cls_name = self.__class__.__name__
+        raise NotImplementedError(f"{cls_name} does not define batch_shape property")
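A short sketch of the new base-class contract: accessing batch_shape on a GP subclass that does not define it now fails with an informative NotImplementedError rather than an AttributeError. BareGP is a hypothetical subclass for illustration only.

# Sketch: subclasses without their own batch_shape hit the base-class error.
from gpytorch.models import GP


class BareGP(GP):
    pass


try:
    BareGP().batch_shape
except NotImplementedError as err:
    print(err)  # "BareGP does not define batch_shape property"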

gpytorch/models/model_list.py

Lines changed: 18 additions & 0 deletions
@@ -31,6 +31,24 @@ def __init__(self, *models):
         )
         self.likelihood = LikelihoodList(*[m.likelihood for m in models])

+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        batch_shape = self.models[0].batch_shape
+        if all(batch_shape == m.batch_shape for m in self.models[1:]):
+            return batch_shape
+        # TODO: Allow broadcasting of model batch shapes
+        raise NotImplementedError(
+            f"`{self.__class__.__name__}.batch_shape` is only supported if all "
+            "constituent models have the same `batch_shape`."
+        )
+
     def forward_i(self, i, *args, **kwargs):
         return self.models[i].forward(*args, **kwargs)
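A standalone illustration of the consistency rule above: the list-level batch_shape is defined only when every submodel reports the same batch_shape (plain torch.Size values stand in for real submodels here):

# Standalone sketch of the model-list consistency check.
import torch

submodel_batch_shapes = [torch.Size([2]), torch.Size([2]), torch.Size([3])]

batch_shape = submodel_batch_shapes[0]
if all(batch_shape == bs for bs in submodel_batch_shapes[1:]):
    print(batch_shape)
else:
    # Mirrors the NotImplementedError path; broadcasting is left as a TODO in the commit.
    print("batch_shape undefined: submodels disagree")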

gpytorch/test/model_test_case.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,7 @@ def test_forward_train(self):
         data = self.create_test_data()
         likelihood, labels = self.create_likelihood_and_labels()
         model = self.create_model(data, labels, likelihood)
+        self.assertEqual(model.batch_shape, data.shape[:-2])  # test batch_shape property
         model.train()
         output = model(data)
         self.assertTrue(output.lazy_covariance_matrix.dim() == 2)
@@ -42,6 +43,7 @@ def test_batch_forward_train(self):
         batch_data = self.create_batch_test_data()
         likelihood, labels = self.create_batch_likelihood_and_labels()
         model = self.create_model(batch_data, labels, likelihood)
+        self.assertEqual(model.batch_shape, batch_data.shape[:-2])  # test batch_shape property
         model.train()
         output = model(batch_data)
         self.assertTrue(output.lazy_covariance_matrix.dim() == 3)
@@ -52,6 +54,7 @@ def test_multi_batch_forward_train(self):
         batch_data = self.create_batch_test_data(batch_shape=torch.Size([2, 3]))
         likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape=torch.Size([2, 3]))
         model = self.create_model(batch_data, labels, likelihood)
+        self.assertEqual(model.batch_shape, batch_data.shape[:-2])  # test batch_shape property
         model.train()
         output = model(batch_data)
         self.assertTrue(output.lazy_covariance_matrix.dim() == 4)

gpytorch/variational/_variational_strategy.py

Lines changed: 6 additions & 1 deletion
@@ -90,11 +90,16 @@ def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Pre-processing step in __call__ to make x the same batch_shape as the inducing points
         """
-        batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
+        batch_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-2])
         inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
         x = x.expand(*batch_shape, *x.shape[-2:])
         return x, inducing_points

+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the variational strategy."""
+        return self.inducing_points.shape[:-2]
+
     @property
     def jitter_val(self) -> float:
         if self._jitter_val is None:
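Since batch_shape is defined as inducing_points.shape[:-2], the _expand_inputs change is behavior-preserving. A standalone sketch of the expansion step, with illustrative shapes:

# Standalone sketch of _expand_inputs: both tensors are expanded to the broadcast
# of the strategy's batch shape (inducing_points.shape[:-2]) and the input's.
import torch

inducing_points = torch.randn(3, 16, 2)  # strategy batch shape (3,), 16 points, d=2
x = torch.randn(5, 1, 8, 2)              # test batch shape (5, 1), q=8

batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
x = x.expand(*batch_shape, *x.shape[-2:])
print(batch_shape, x.shape, inducing_points.shape)
# torch.Size([5, 3]) torch.Size([5, 3, 8, 2]) torch.Size([5, 3, 16, 2])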

gpytorch/variational/lmc_variational_strategy.py

Lines changed: 10 additions & 7 deletions
@@ -116,26 +116,24 @@ def __init__(
         Module.__init__(self)
         self.base_variational_strategy = base_variational_strategy
         self.num_tasks = num_tasks
-        batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
+        vdist_batch_shape = self.base_variational_strategy._variational_distribution.batch_shape

         # Check if no functions
         if latent_dim >= 0:
             raise RuntimeError(f"latent_dim must be a negative indexed batch dimension: got {latent_dim}.")
-        if not (batch_shape[latent_dim] == num_latents or batch_shape[latent_dim] == 1):
+        if not (vdist_batch_shape[latent_dim] == num_latents or vdist_batch_shape[latent_dim] == 1):
             raise RuntimeError(
-                f"Mismatch in num_latents: got a variational distribution of batch shape {batch_shape}, "
+                f"Mismatch in num_latents: got a variational distribution of batch shape {vdist_batch_shape}, "
                 f"expected the function dim {latent_dim} to be {num_latents}."
             )
         self.num_latents = num_latents
         self.latent_dim = latent_dim

         # Make the batch_shape
-        self.batch_shape = list(batch_shape)
-        del self.batch_shape[self.latent_dim]
-        self.batch_shape = torch.Size(self.batch_shape)
+        self._batch_shape = vdist_batch_shape[: self.latent_dim] + vdist_batch_shape[self.latent_dim + 1 :]

         # LCM coefficients
-        lmc_coefficients = torch.randn(*batch_shape, self.num_tasks)
+        lmc_coefficients = torch.randn(*vdist_batch_shape, self.num_tasks)
         self.register_parameter("lmc_coefficients", torch.nn.Parameter(lmc_coefficients))

         if jitter_val is None:
@@ -145,6 +143,11 @@ def __init__(
         else:
             self.jitter_val = jitter_val

+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the variational strategy."""
+        return self._batch_shape
+
     @property
     def prior_distribution(self) -> MultivariateNormal:
         return self.base_variational_strategy.prior_distribution
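Here the list/del dance is replaced by torch.Size slicing, which keeps the value a torch.Size throughout. One subtlety worth flagging alongside the commit's TODO: for latent_dim = -1, vdist_batch_shape[latent_dim + 1:] equals vdist_batch_shape[0:], i.e. the whole shape, so the slice only drops the dimension cleanly for latent_dim < -1. A standalone sketch that normalizes the negative index first (the drop_dim helper is illustrative, not part of the commit):

# Standalone sketch of dropping a (negative-indexed) dimension from a torch.Size.
import torch

def drop_dim(shape: torch.Size, dim: int) -> torch.Size:
    dim = dim % len(shape)  # normalize a negative index so dim + 1 never wraps to 0
    return shape[:dim] + shape[dim + 1:]

vdist_batch_shape = torch.Size([2, 4, 3])  # e.g. num_latents = 4 at latent_dim = -2
print(drop_dim(vdist_batch_shape, -2))     # torch.Size([2, 3])
print(drop_dim(vdist_batch_shape, -1))     # torch.Size([2, 4])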

test/models/test_exact_gp.py

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,10 @@ def test_batch_forward_then_nonbatch_forward_eval(self):
         batch_data = self.create_batch_test_data()
         likelihood, labels = self.create_batch_likelihood_and_labels()
         model = self.create_model(batch_data, labels, likelihood)
+
+        # test batch_shape property
+        self.assertEqual(model.batch_shape, batch_data.shape[:-2])
+
         model.eval()
         output = model(batch_data)


test/models/test_variational_gp.py

Lines changed: 3 additions & 2 deletions
@@ -12,8 +12,9 @@

 class GPClassificationModel(ApproximateGP):
     def __init__(self, train_x, use_inducing=False):
-        variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=train_x.shape[:-2])
-        inducing_points = torch.randn(50, train_x.size(-1)) if use_inducing else train_x
+        batch_shape = train_x.shape[:-2]
+        variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=batch_shape)
+        inducing_points = torch.randn(*batch_shape, 50, train_x.size(-1)) if use_inducing else train_x
         strategy_cls = VariationalStrategy
         variational_strategy = strategy_cls(
             self, inducing_points, variational_distribution, learn_inducing_locations=use_inducing
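A sketch of why the test now batches its inducing points: ApproximateGP.batch_shape reads inducing_points.shape[:-2] via the variational strategy, so unbatched inducing points would report an empty batch shape (shapes illustrative):

# Sketch: inducing points must carry the model's batch dimensions.
import torch

train_x = torch.randn(2, 10, 1)  # batch shape (2,)
batch_shape = train_x.shape[:-2]

old_inducing = torch.randn(50, train_x.size(-1))                # shape[:-2] == torch.Size([])
new_inducing = torch.randn(*batch_shape, 50, train_x.size(-1))  # shape[:-2] == torch.Size([2])
print(old_inducing.shape[:-2], new_inducing.shape[:-2])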
