Skip to content

Commit 0abb6f2

Browse files
committed
update MSLL implementation
1 parent 87ebca3 commit 0abb6f2

File tree

1 file changed

+16
-43
lines changed

1 file changed

+16
-43
lines changed

autoemulate/core/metrics.py

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,14 @@ class MSLLMetric(ProbabilisticMetric):
287287
MSLL evaluates the quality of probabilistic predictions by measuring the
288288
log-likelihood of the true values under the predictive distribution,
289289
standardized by the log-likelihood under the trivial model (i.e., predictive
290-
distribution parameterized with the data mean and variance).
290+
normal distribution parameterized with the data mean and variance).
291291
292292
If no training data is supplied, the unstandardized mean log loss is returned instead.
293293
294294
Lower MSLL values indicate better predictive performance.
295295
296-
Note: This metric requires probabilistic predictions.
296+
Note: This metric requires probabilistic predictions. Standardization
297+
assumes that the predictive distribution is Gaussian.
297298
298299
Attributes
299300
----------
@@ -324,7 +325,7 @@ def __call__(
324325
y_true: TensorLike
325326
True target values.
326327
metric_params: MetricParams
327-
Metric parameters including: n_samples, y_train, reduction.
328+
Metric parameters including: y_train.
328329
329330
Returns
330331
-------
@@ -350,37 +351,12 @@ def __call__(
350351
# Ensure 2D y_true for consistent handling
351352
y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
352353

353-
# Handle distributions without mean/variance attributes
354-
try:
355-
y_pred_mean, y_pred_var = y_pred.mean, y_pred.variance
356-
except Exception:
357-
y_pred_samples = y_pred.rsample(torch.Size([metric_params.n_samples]))
358-
y_pred_mean = y_pred_samples.mean(dim=0)
359-
y_pred_var = y_pred_samples.var(dim=0)
360-
361-
if y_pred_mean.shape != y_true.shape:
362-
raise ValueError(
363-
f"Predictions shape {y_pred_mean.shape} does not match "
364-
f"y_true shape {y_true.shape}."
365-
)
366-
367354
# Compute mean log loss
368-
mean_log_loss = (
369-
0.5 * torch.log(2 * torch.pi * y_pred_var)
370-
+ torch.square(y_true - y_pred_mean) / (2 * y_pred_var)
371-
).mean(dim=0)
355+
mean_log_loss = y_pred.log_prob(y_true).mean(dim=0)
372356

373357
# If no training data, return mean log loss
374358
if metric_params.y_train is None:
375-
if metric_params.reduction == "none":
376-
return mean_log_loss
377-
if metric_params.reduction == "mean":
378-
return mean_log_loss.mean()
379-
msg = (
380-
f"Invalid reduction '{metric_params.reduction}'. "
381-
"Expected 'mean' or 'none'."
382-
)
383-
raise ValueError(msg)
359+
return mean_log_loss
384360

385361
# Ensure 2D y_train for consistent handling
386362
y_train = (
@@ -389,7 +365,8 @@ def __call__(
389365
else metric_params.y_train
390366
)
391367

392-
y_train_mean = y_train.mean(dim=0)
368+
y_train_mean = y_train.mean()
369+
393370
# Following the GPyTorch implementation, use the global variance rather than a per-task variance
394371
# https://github.com/cornellius-gp/gpytorch/blob/c0fb6c64311fdbef2862fd3ba2bd613fbd081e79/gpytorch/metrics/metrics.py#L60
395372
y_train_var = y_train.var()
@@ -398,20 +375,16 @@ def __call__(
398375
y_train_var = torch.clamp(y_train_var, min=1e-6)
399376

400377
# Compute mean log loss (negative log probability) under the trivial Gaussian model
401-
mean_trivial_log_loss = 0.5 * (
402-
torch.log(2 * torch.pi * y_train_var)
403-
+ torch.square(y_true - y_train_mean) / (2 * y_train_var)
404-
).mean(dim=0)
378+
mean_trivial_log_loss = (
379+
0.5
380+
* (
381+
torch.log(2 * torch.pi * y_train_var)
382+
+ torch.square(y_true - y_train_mean) / (2 * y_train_var)
383+
).mean()
384+
)
405385

406386
# Return mean standardized log loss
407-
if metric_params.reduction == "none":
408-
return mean_log_loss - mean_trivial_log_loss
409-
if metric_params.reduction == "mean":
410-
return (mean_log_loss - mean_trivial_log_loss).mean()
411-
msg = (
412-
f"Invalid reduction '{metric_params.reduction}'. Expected 'mean' or 'none'."
413-
)
414-
raise ValueError(msg)
387+
return mean_log_loss - mean_trivial_log_loss
415388

416389

417390
R2 = TorchMetrics(

0 commit comments

Comments
 (0)