Skip to content

Commit 597977d

Browse files
committed
raise error if output distribution does not have a variance property
1 parent 40d7530 commit 597977d

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

autoemulate/learners/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,16 @@ def fit(self, *args):
194194
self.metrics["n_queries"].append(self.n_queries)
195195
self.logger.info("Metrics updated: MSE=%s, R2=%s", mse_val, r2_val)
196196

197-
# If Gaussian output
197+
# If distribution output
198198
# TODO: check generality for other GPs (e.g. with full covariance)
199199
if isinstance(output, DistributionLike):
200+
if not hasattr(output, "variance"):
201+
msg = (
202+
f"Output of type {type(output)} does not have a 'variance'"
203+
"property. This may occur if output is a PyTorch "
204+
"TransformedDistribution."
205+
)
206+
raise AttributeError(msg)
200207
assert isinstance(output.variance, TensorLike)
201208
assert output.variance.ndim == 2
202209
assert output.variance.shape[1] == self.out_dim

0 commit comments

Comments
 (0)