
Commit fb2cd26

Merge pull request #924 from alan-turing-institute/add_msll
Add MSLL metric
2 parents: c7ea198 + cb35677

5 files changed: 381 additions, 33 deletions

autoemulate/core/compare.py

Lines changed: 3 additions & 6 deletions

@@ -13,12 +13,7 @@

 from autoemulate.core.device import TorchDeviceMixin
 from autoemulate.core.logging_config import get_configured_logger
-from autoemulate.core.metrics import (
-    R2,
-    Metric,
-    get_metric,
-    get_metrics,
-)
+from autoemulate.core.metrics import R2, Metric, MetricParams, get_metric, get_metrics
 from autoemulate.core.model_selection import bootstrap, evaluate
 from autoemulate.core.plotting import (
     calculate_subplot_layout,
@@ -489,6 +484,7 @@ def compare(self):
             n_bootstraps=self.n_bootstraps,
             device=self.device,
             metrics=self.evaluation_metrics,
+            metric_params=MetricParams(y_train=train_val_y),
         )
         test_metrics = bootstrap(
             transformed_emulator,
@@ -497,6 +493,7 @@ def compare(self):
             n_bootstraps=self.n_bootstraps,
             device=self.device,
             metrics=self.evaluation_metrics,
+            metric_params=MetricParams(y_train=train_val_y),
         )

         # Log all test metrics from test_metrics dictionary
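
The change to compare() above threads the training/validation targets into the bootstrap evaluation via MetricParams: MSLL needs y_train to parameterize its trivial baseline model, while the existing metrics simply ignore that field. A minimal sketch of the calling convention (score_all is a hypothetical helper for illustration, not part of autoemulate, and y_pred is assumed to be a predictive distribution):

from autoemulate.core.metrics import AVAILABLE_METRICS, MetricParams

def score_all(y_pred, y_true, y_train):
    # One MetricParams object is shared across metrics; only MSLL reads y_train,
    # while the TorchMetrics-based metrics and CRPS only use n_samples.
    params = MetricParams(y_train=y_train)
    return {
        name: metric(y_pred, y_true, params)
        for name, metric in AVAILABLE_METRICS.items()
    }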

autoemulate/core/metrics.py

Lines changed: 187 additions & 9 deletions

@@ -4,11 +4,14 @@

 from abc import abstractmethod
 from collections.abc import Sequence
+from dataclasses import dataclass
 from functools import partial, total_ordering
+from typing import Literal

 import torch
 import torchmetrics
 from einops import rearrange
+from torch.distributions import Independent
 from torchmetrics.regression.crps import ContinuousRankedProbabilityScore

 from autoemulate.core.types import (
@@ -19,6 +22,31 @@
 )


+@dataclass
+class MetricParams:
+    """
+    Parameters for metric calculations.
+
+    Attributes
+    ----------
+    n_samples: int
+        Number of samples to draw from the predicted distribution if `y_pred` is a
+        distribution. Defaults to 1000.
+    y_train: TensorLike | None
+        Training target values. In MSLL, used to parameterize the trivial model for
+        standardization. If None, mean log loss is computed without standardization.
+        Defaults to None.
+    reduction: Literal["mean", "none"]
+        Reduction method to apply to the final metric scores computed per task.
+        Options are 'mean' or 'none'. Defaults to 'mean'.
+    """
+
+    n_samples: int = 1000
+    y_train: TensorLike | None = None
+    reduction: Literal["mean", "none"] = "mean"
+    metric_kwargs: dict | None = None  # supports subclasses with arbitrary new kwargs
+
+
 @total_ordering
 class Metric:
     """Configuration for a single metric.
@@ -58,7 +86,10 @@ def __lt__(self, other: Metric) -> bool:

     @abstractmethod
     def __call__(
-        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+        self,
+        y_pred: OutputLike,
+        y_true: TensorLike,
+        metric_params: MetricParams | None = None,
     ) -> TensorLike:
         """Calculate metric."""

@@ -87,20 +118,28 @@ def __init__(
         self.maximize = maximize

     def __call__(
-        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+        self,
+        y_pred: OutputLike,
+        y_true: TensorLike,
+        metric_params: MetricParams | None = None,
     ) -> TensorLike:
         """Calculate metric."""
         if not isinstance(y_pred, OutputLike):
             raise ValueError(f"Metric not implemented for y_pred ({type(y_pred)})")
         if not isinstance(y_true, TensorLike):
             raise ValueError(f"Metric not implemented for y_true ({type(y_true)})")

+        if metric_params is None:
+            metric_params = MetricParams()
+
         # Handle probabilistic predictions
         if isinstance(y_pred, DistributionLike):
             try:
                 y_pred = y_pred.mean
             except Exception:
-                y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
+                y_pred = y_pred.rsample(torch.Size([metric_params.n_samples])).mean(
+                    dim=0
+                )
         metric = self.metric()
         metric.to(y_pred.device)

@@ -117,7 +156,10 @@ class ProbabilisticMetric(Metric):

     @abstractmethod
     def __call__(
-        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+        self,
+        y_pred: OutputLike,
+        y_true: TensorLike,
+        metric_params: MetricParams | None = None,
     ) -> TensorLike:
         """Calculate metric."""

@@ -145,7 +187,10 @@ class CRPSMetric(ProbabilisticMetric):
     maximize: bool = False

     def __call__(
-        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+        self,
+        y_pred: OutputLike,
+        y_true: TensorLike,
+        metric_params: MetricParams | None = None,
     ) -> TensorLike:
         """Calculate CRPS metric.

@@ -167,9 +212,8 @@ def __call__(
             - If distribution: `n_samples` are drawn to estimate CRPS.
         y_true: TensorLike
             True target values of shape `(batch_size, *target_shape)`.
-        n_samples: int
-            Number of samples to draw from the predicted distribution if `y_pred` is a
-            distribution. Defaults to 1000.
+        metric_params: MetricParams
+            Metric parameters including: n_samples.

         Returns
         -------
@@ -184,6 +228,9 @@ def __call__(
         if not isinstance(y_true, TensorLike):
             raise ValueError(f"y_true must be a tensor, got {type(y_true)}")

+        if metric_params is None:
+            metric_params = MetricParams()
+
         # Ensure 2D y_true for consistent handling
         y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true

@@ -195,7 +242,7 @@ def __call__(
         if isinstance(y_pred, DistributionLike):
             # Distribution case: sample from it
             samples = rearrange(
-                y_pred.sample(torch.Size((n_samples,))),
+                y_pred.sample(torch.Size((metric_params.n_samples,))),
                 "s b ... -> b ... s",
             )
             if samples.shape[:-1] != y_true.shape:
@@ -236,6 +283,134 @@ def __call__(
         return crps_metric(samples_flat, y_true_flat)


+class MSLLMetric(ProbabilisticMetric):
+    """Mean Standardized Log Loss (MSLL) metric.
+
+    MSLL evaluates the quality of probabilistic predictions by measuring the
+    log-likelihood of the true values under the predictive distribution,
+    standardized by the log-likelihood under the trivial model (i.e., a predictive
+    normal distribution parameterized with the training data mean and variance).
+
+    If no training data is supplied, the mean log loss is computed.
+
+    Lower MSLL values indicate better predictive performance.
+
+    Note: This metric requires probabilistic predictions. Standardization
+    assumes that the predictive distribution is Gaussian.
+
+    Attributes
+    ----------
+    name: str
+        Display name for the metric.
+    maximize: bool
+        Whether higher values are better. False for MSLL (lower is better).
+    """
+
+    name: str = "msll"
+    maximize: bool = False
+
+    def __call__(
+        self,
+        y_pred: OutputLike,
+        y_true: TensorLike,
+        metric_params: MetricParams | None = None,
+    ) -> TensorLike:
+        """Calculate MSLL metric.
+
+        If no training data is provided in `metric_params.y_train`, the mean log loss
+        is computed without standardization.
+
+        Parameters
+        ----------
+        y_pred: OutputLike
+            Predicted outputs. Must be a distribution.
+        y_true: TensorLike
+            True target values.
+        metric_params: MetricParams
+            Metric parameters including: y_train and reduction.
+
+        Returns
+        -------
+        TensorLike
+            Mean Standardized Log Loss (MSLL) score.
+
+        Raises
+        ------
+        ValueError
+            If y_pred is not a distribution.
+        """
+        if not isinstance(y_pred, DistributionLike):
+            raise ValueError(
+                f"MSLL metric requires probabilistic predictions, got {type(y_pred)}."
+            )
+
+        if not isinstance(y_true, TensorLike):
+            raise ValueError(f"y_true must be a tensor, got {type(y_true)}")
+
+        if metric_params is None:
+            metric_params = MetricParams()
+
+        # Ensure 2D y_true for consistent handling
+        y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
+
+        # Compute mean negative log likelihood (also per output dimension for an
+        # Independent distribution, to support 'none' reduction)
+        if isinstance(y_pred, Independent):
+            model_nll_output = -y_pred.base_dist.log_prob(y_true).mean(dim=0)
+            model_nll_total = model_nll_output.mean()
+        else:
+            model_nll_output = None
+            model_nll_total = -y_pred.log_prob(y_true).mean()
+
+        # If no training data, return mean log loss
+        if metric_params.y_train is None:
+            if metric_params.reduction == "mean":
+                return model_nll_total
+            if metric_params.reduction == "none":
+                if model_nll_output is None:
+                    msg = (
+                        "Per-output MLL not available for non-Independent "
+                        "distributions."
+                    )
+                    raise ValueError(msg)
+                return model_nll_output.reshape(*y_true.shape[1:])
+            msg = (
+                f"Unknown reduction method: {metric_params.reduction}. "
+                "Expected 'mean' or 'none'."
+            )
+            raise ValueError(msg)
+
+        # Keep original shape for y_train_mean to match y_true shape
+        y_train_mean = metric_params.y_train.mean(dim=0, keepdim=True)
+
+        # Following the GPyTorch implementation, use the global variance rather than
+        # the per-task variance:
+        # https://github.com/cornellius-gp/gpytorch/blob/c0fb6c64311fdbef2862fd3ba2bd613fbd081e79/gpytorch/metrics/metrics.py#L60
+        y_train_var = metric_params.y_train.var()
+
+        # Avoid numerical issues
+        y_train_var = torch.clamp(y_train_var, min=1e-6)
+
+        # Compute mean negative log likelihood under trivial Gaussian model
+        trivial_nll_output = 0.5 * (
+            torch.log(2 * torch.pi * y_train_var)
+            + torch.square(y_true - y_train_mean) / (2 * y_train_var)
+        ).mean(dim=0)
+
+        # Return mean standardized log loss
+        if metric_params.reduction == "mean":
+            return model_nll_total - trivial_nll_output.mean()
+        if metric_params.reduction == "none":
+            if model_nll_output is None:
+                msg = "Per-output MLL not available for non-Independent distributions."
+                raise ValueError(msg)
+            return (model_nll_output - trivial_nll_output).reshape(*y_true.shape[1:])
+        msg = (
+            f"Unknown reduction method: {metric_params.reduction}. "
+            "Expected 'mean' or 'none'."
+        )
+        raise ValueError(msg)
+
+
 R2 = TorchMetrics(
     metric=torchmetrics.R2Score,
     name="r2",
@@ -262,12 +437,15 @@ def __call__(

 CRPS = CRPSMetric()

+MSLL = MSLLMetric()
+
 AVAILABLE_METRICS = {
     "r2": R2,
     "rmse": RMSE,
     "mse": MSE,
     "mae": MAE,
     "crps": CRPS,
+    "msll": MSLL,
 }

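As described in the MSLLMetric docstring above, the metric standardizes the model's mean negative log-likelihood against a trivial Gaussian fit to the training targets. Schematically (notation added here for illustration, not taken from the source):

\mathrm{MSLL} = \frac{1}{N}\sum_{i=1}^{N}\bigl[-\log p(y_i)\bigr] - \frac{1}{N}\sum_{i=1}^{N}\bigl[-\log \mathcal{N}\bigl(y_i;\ \bar{y}_{\mathrm{train}},\ \sigma^{2}_{\mathrm{train}}\bigr)\bigr]

When y_train is omitted, the unstandardized mean log loss is returned instead.

A short usage sketch of the new metric (illustrative tensors and shapes only; it assumes DistributionLike covers torch.distributions objects such as Independent):

import torch
from torch.distributions import Independent, Normal

from autoemulate.core.metrics import MSLL, MetricParams

# Illustrative data: 200 test points, 2 output tasks, 500 training targets.
y_true = torch.randn(200, 2)
y_train = torch.randn(500, 2)

# A factorized Gaussian predictive distribution over the two outputs.
y_pred = Independent(Normal(torch.zeros(200, 2), torch.ones(200, 2)), 1)

# Standardized score (lower is better); y_train parameterizes the trivial model.
score = MSLL(y_pred, y_true, MetricParams(y_train=y_train))

# Per-output scores require an Independent distribution and reduction="none".
per_output = MSLL(y_pred, y_true, MetricParams(y_train=y_train, reduction="none"))

# Without y_train, the unstandardized mean log loss is returned instead.
mll = MSLL(y_pred, y_true)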