Skip to content

Commit 93d0b50

Browse files
committed
Add support for per-output MSLL when predictions are Independent distributions
1 parent ad172ae commit 93d0b50

File tree

2 files changed

+58
-13
lines changed

2 files changed

+58
-13
lines changed

autoemulate/core/metrics.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
import torchmetrics
1313
from einops import rearrange
14+
from torch.distributions import Independent
1415
from torchmetrics.regression.crps import ContinuousRankedProbabilityScore
1516

1617
from autoemulate.core.types import (
@@ -351,12 +352,32 @@ def __call__(
351352
# Ensure 2D y_true for consistent handling
352353
y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
353354

354-
# Compute mean negative log likelihood
355-
mean_log_loss = -y_pred.log_prob(y_true).mean()
355+
# Compute mean negative log likelihood (also per output dimension if given an
356+
# Independent distribution)
357+
if isinstance(y_pred, Independent):
358+
model_nll_output = -y_pred.base_dist.log_prob(y_true).mean(dim=0)
359+
model_nll_total = model_nll_output.mean()
360+
else:
361+
model_nll_output = None
362+
model_nll_total = -y_pred.log_prob(y_true).mean()
356363

357364
# If no training data, return mean log loss
358365
if metric_params.y_train is None:
359-
return mean_log_loss
366+
if metric_params.reduction == "mean":
367+
return model_nll_total
368+
if metric_params.reduction == "none":
369+
if model_nll_output is None:
370+
msg = (
371+
"Per-output MLL not available for non-Independent "
372+
"distributions."
373+
)
374+
raise ValueError(msg)
375+
return model_nll_output
376+
msg = (
377+
f"Unknown reduction method: {metric_params.reduction}. "
378+
"Expected 'mean' or 'none'."
379+
)
380+
raise ValueError(msg)
360381

361382
# Ensure 2D y_train for consistent handling
362383
y_train_mean = metric_params.y_train.mean(dim=0, keepdim=True).view(1, -1)
@@ -369,16 +390,24 @@ def __call__(
369390
y_train_var = torch.clamp(y_train_var, min=1e-6)
370391

371392
# Compute mean negative log likelihood under trivial Gaussian model
372-
mean_trivial_log_loss = (
373-
0.5
374-
* (
375-
torch.log(2 * torch.pi * y_train_var)
376-
+ torch.square(y_true - y_train_mean) / (2 * y_train_var)
377-
).mean()
378-
)
393+
trivial_nll_output = 0.5 * (
394+
torch.log(2 * torch.pi * y_train_var)
395+
+ torch.square(y_true - y_train_mean) / (2 * y_train_var)
396+
).mean(dim=0)
379397

380398
# Return mean standardized log loss
381-
return mean_log_loss - mean_trivial_log_loss
399+
if metric_params.reduction == "mean":
400+
return model_nll_total - trivial_nll_output.mean()
401+
if metric_params.reduction == "none":
402+
if model_nll_output is None:
403+
msg = "Per-output MLL not available for non-Independent distributions."
404+
raise ValueError(msg)
405+
return model_nll_output - trivial_nll_output
406+
msg = (
407+
f"Unknown reduction method: {metric_params.reduction}. "
408+
"Expected 'mean' or 'none'."
409+
)
410+
raise ValueError(msg)
382411

383412

384413
R2 = TorchMetrics(

tests/core/test_metrics.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
get_metric,
2222
get_metrics,
2323
)
24-
from torch.distributions import Normal
24+
from torch.distributions import Independent, Normal
2525

2626
# Tests for the base Metric class
2727

@@ -578,10 +578,26 @@ def test_msll_with_2d_inputs():
578578
y_true,
579579
metric_params=MetricParams(y_train=y_train),
580580
)
581-
582581
assert isinstance(msll, torch.Tensor)
583582
assert msll.shape == torch.Size([])
584583

584+
# can't compute per-output MSLL for non-Independent distributions
585+
with pytest.raises(ValueError, match="Per-output MLL not available"):
586+
msll = MSLL(
587+
y_pred,
588+
y_true,
589+
metric_params=MetricParams(y_train=y_train, reduction="none"),
590+
)
591+
592+
y_pred = Independent(Normal(loc=y_true, scale=torch.ones_like(y_true)), 1)
593+
msll = MSLL(
594+
y_pred,
595+
y_true,
596+
metric_params=MetricParams(y_train=y_train, reduction="none"),
597+
)
598+
assert isinstance(msll, torch.Tensor)
599+
assert msll.shape == torch.Size([3])
600+
585601

586602
def test_msll_raises_for_non_distribution():
587603
"""Test MSLL raises ValueError for non-distribution predictions."""

0 commit comments

Comments
 (0)