@@ -50,6 +50,11 @@ class ExactGP(GP):
         >>> # test_x = ...;
         >>> model(test_x)  # Returns the GP latent function at test_x
         >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+
+    :ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
+        independent of the internal representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
     """

     def __init__(self, train_inputs, train_targets, likelihood):
@@ -71,6 +76,17 @@ def __init__(self, train_inputs, train_targets, likelihood):

         self.prediction_strategy = None

+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self.train_inputs[0].shape[:-2]
+
     @property
     def train_targets(self):
         return self._train_targets
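
Not part of the commit, but a minimal sketch of how the new `batch_shape` property is meant to behave. `SimpleBatchGP`, the constant-mean/RBF setup, and all shapes below are illustrative assumptions rather than code from this diff; the expected output shapes follow the docstring added above.

    # Illustrative sketch only; SimpleBatchGP is a hypothetical minimal ExactGP subclass.
    import torch
    import gpytorch

    class SimpleBatchGP(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            batch = train_x.shape[:-2]
            self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch)
            self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=batch)

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

    likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([3]))
    train_x = torch.rand(3, 10, 2)  # model batch shape (3,), n = 10, d = 2
    train_y = torch.rand(3, 10)
    model = SimpleBatchGP(train_x, train_y, likelihood)

    print(model.batch_shape)  # torch.Size([3]) -- i.e. train_x.shape[:-2]

    model.eval()
    likelihood.eval()
    with torch.no_grad():
        test_x = torch.rand(5, 1, 4, 2)  # test_batch_shape = (5, 1), q = 4, d = 2
        posterior = likelihood(model(test_x))
        # Per the docstring: broadcast(test_batch_shape, model.batch_shape) x q
        print(posterior.batch_shape, posterior.event_shape)  # expected: torch.Size([5, 3]) torch.Size([4])
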
@@ -160,8 +176,6 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
160176 "all test independent caches exist. Call the model on some data first!"
161177 )
162178
163- model_batch_shape = self .train_inputs [0 ].shape [:- 2 ]
164-
165179 if not isinstance (inputs , list ):
166180 inputs = [inputs ]
167181
@@ -184,17 +198,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):

         # Check whether we can properly broadcast batch dimensions
         try:
-            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
+            torch.broadcast_shapes(self.batch_shape, target_batch_shape)
         except RuntimeError:
             raise RuntimeError(
-                f"Model batch shape ({model_batch_shape}) and target batch shape "
+                f"Model batch shape ({self.batch_shape}) and target batch shape "
                 f"({target_batch_shape}) are not broadcastable."
             )

-        if len(model_batch_shape) > len(input_batch_shape):
-            input_batch_shape = model_batch_shape
-        if len(model_batch_shape) > len(target_batch_shape):
-            target_batch_shape = model_batch_shape
+        if len(self.batch_shape) > len(input_batch_shape):
+            input_batch_shape = self.batch_shape
+        if len(self.batch_shape) > len(target_batch_shape):
+            target_batch_shape = self.batch_shape

         # If input has no fantasy batch dimension but target does, we can save memory and computation by not
         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
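
As a standalone illustration of the compatibility check in this hunk, a small sketch; the shapes here are made up for the example, not taken from the diff.

    # Sketch of the broadcast check used in get_fantasy_model; shapes are illustrative.
    import torch

    model_batch_shape = torch.Size([2, 3])      # what the new `self.batch_shape` property would return
    target_batch_shape = torch.Size([5, 1, 3])  # batch shape of the fantasy targets

    try:
        torch.broadcast_shapes(model_batch_shape, target_batch_shape)  # -> torch.Size([5, 2, 3])
    except RuntimeError:
        raise RuntimeError(
            f"Model batch shape ({model_batch_shape}) and target batch shape "
            f"({target_batch_shape}) are not broadcastable."
        )

    # The subsequent len(...) comparisons in the diff promote whichever shape is
    # shorter to the model's batch shape, so the later expand calls line up: here
    # target_batch_shape is longer and kept as-is, while a shorter one would be
    # replaced by model_batch_shape.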