@@ -71,6 +71,17 @@ def __init__(self, train_inputs, train_targets, likelihood):
 
         self.prediction_strategy = None
 
+    @property
+    def batch_shape(self) -> torch.Size:
+        r"""The batch shape of the model.
+
+        This is a batch shape from an I/O perspective, independent of the internal
+        representation of the model. For a model with `(m)` outputs, a
+        `test_batch_shape x q x d`-shaped input to the model in eval mode returns a
+        distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
+        """
+        return self.train_inputs[0].shape[:-2]
+
     @property
     def train_targets(self):
         return self._train_targets
@@ -160,9 +171,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
                 "all test independent caches exist. Call the model on some data first!"
             )
 
-        model_batch_shape = self.train_inputs[0].shape[:-2]
-
-        if self.train_targets.dim() > len(model_batch_shape) + 1:
+        if self.train_targets.dim() > len(self.batch_shape) + 1:
             raise RuntimeError("Cannot yet add fantasy observations to multitask GPs, but this is coming soon!")
 
         if not isinstance(inputs, list):
@@ -182,17 +191,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
 
         # Check whether we can properly broadcast batch dimensions
         try:
-            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
+            torch.broadcast_shapes(self.batch_shape, target_batch_shape)
         except RuntimeError:
             raise RuntimeError(
-                f"Model batch shape ({model_batch_shape}) and target batch shape "
+                f"Model batch shape ({self.batch_shape}) and target batch shape "
                 f"({target_batch_shape}) are not broadcastable."
             )
 
-        if len(model_batch_shape) > len(input_batch_shape):
-            input_batch_shape = model_batch_shape
-        if len(model_batch_shape) > len(target_batch_shape):
-            target_batch_shape = model_batch_shape
+        if len(self.batch_shape) > len(input_batch_shape):
+            input_batch_shape = self.batch_shape
+        if len(self.batch_shape) > len(target_batch_shape):
+            target_batch_shape = self.batch_shape
 
         # If input has no fantasy batch dimension but target does, we can save memory and computation by not
         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
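
A minimal usage sketch of the new `batch_shape` property (not part of this diff; the `BatchedExactGP` class, its modules, and all shapes below are illustrative, assuming a standard batched GPyTorch ExactGP setup):

import torch
import gpytorch


class BatchedExactGP(gpytorch.models.ExactGP):
    # Hypothetical example model: a batch of 3 independent GPs over 2-d inputs.
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        batch_shape = train_x.shape[:-2]
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


train_x = torch.randn(3, 10, 2)  # model batch_shape (3,), n=10, d=2
train_y = torch.randn(3, 10)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([3]))
model = BatchedExactGP(train_x, train_y, likelihood)

# The new property reads the batch shape off the training inputs.
print(model.batch_shape)  # torch.Size([3])

# In eval mode, a `test_batch_shape x q x d` input broadcasts against
# model.batch_shape, as the docstring describes: broadcast((5, 1), (3,)) == (5, 3).
model.eval()
likelihood.eval()
with torch.no_grad():
    test_x = torch.randn(5, 1, 4, 2)  # test_batch_shape (5, 1), q=4
    posterior = model(test_x)
    print(posterior.batch_shape)  # torch.Size([5, 3])
    print(posterior.event_shape)  # torch.Size([4])

The same `torch.broadcast_shapes` rule is what `get_fantasy_model` now checks against `self.batch_shape` instead of recomputing the shape from `train_inputs` locally.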