44
55from abc import abstractmethod
66from collections .abc import Sequence
7+ from dataclasses import dataclass
78from functools import partial , total_ordering
9+ from typing import Literal
810
911import torch
1012import torchmetrics
1113from einops import rearrange
14+ from torch .distributions import Independent
1215from torchmetrics .regression .crps import ContinuousRankedProbabilityScore
1316
1417from autoemulate .core .types import (
1922)
2023
2124
25+ @dataclass
26+ class MetricParams :
27+ """
28+ Parameters for metric calculations.
29+
30+ Attributes
31+ ----------
32+ n_samples: int
33+ Number of samples to draw from the predicted distribution if `y_pred` is a
34+ distribution. Defaults to 1000.
35+ y_train: TensorLike | None
36+ Training target values. In MSLL used to parameterize the trivial model for
37+ standardization. If None, mean log loss is computed without standardization.
38+ Defaults to None.
39+ reduction: Literal["mean", "none"]
40+ Reduction method to apply to the final metric scores computer per task.
41+ Options are 'mean' or 'none'. Defaults to 'mean'.
42+ """
43+
44+ n_samples : int = 1000
45+ y_train : TensorLike | None = None
46+ reduction : Literal ["mean" , "none" ] = "mean"
47+ metric_kwargs : dict | None = None # supports subclasses with arbitrary new kwargs
48+
49+
2250@total_ordering
2351class Metric :
2452 """Configuration for a single metric.
@@ -58,7 +86,10 @@ def __lt__(self, other: Metric) -> bool:
5886
5987 @abstractmethod
6088 def __call__ (
61- self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
89+ self ,
90+ y_pred : OutputLike ,
91+ y_true : TensorLike ,
92+ metric_params : MetricParams | None = None ,
6293 ) -> TensorLike :
6394 """Calculate metric."""
6495
@@ -87,20 +118,28 @@ def __init__(
87118 self .maximize = maximize
88119
89120 def __call__ (
90- self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
121+ self ,
122+ y_pred : OutputLike ,
123+ y_true : TensorLike ,
124+ metric_params : MetricParams | None = None ,
91125 ) -> TensorLike :
92126 """Calculate metric."""
93127 if not isinstance (y_pred , OutputLike ):
94128 raise ValueError (f"Metric not implemented for y_pred ({ type (y_pred )} )" )
95129 if not isinstance (y_true , TensorLike ):
96130 raise ValueError (f"Metric not implemented for y_true ({ type (y_true )} )" )
97131
132+ if metric_params is None :
133+ metric_params = MetricParams ()
134+
98135 # Handle probabilistic predictions
99136 if isinstance (y_pred , DistributionLike ):
100137 try :
101138 y_pred = y_pred .mean
102139 except Exception :
103- y_pred = y_pred .rsample ((n_samples ,)).mean (dim = 0 )
140+ y_pred = y_pred .rsample (torch .Size ([metric_params .n_samples ])).mean (
141+ dim = 0
142+ )
104143 metric = self .metric ()
105144 metric .to (y_pred .device )
106145
@@ -117,7 +156,10 @@ class ProbabilisticMetric(Metric):
117156
118157 @abstractmethod
119158 def __call__ (
120- self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
159+ self ,
160+ y_pred : OutputLike ,
161+ y_true : TensorLike ,
162+ metric_params : MetricParams | None = None ,
121163 ) -> TensorLike :
122164 """Calculate metric."""
123165
@@ -145,7 +187,10 @@ class CRPSMetric(ProbabilisticMetric):
145187 maximize : bool = False
146188
147189 def __call__ (
148- self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
190+ self ,
191+ y_pred : OutputLike ,
192+ y_true : TensorLike ,
193+ metric_params : MetricParams | None = None ,
149194 ) -> TensorLike :
150195 """Calculate CRPS metric.
151196
@@ -167,9 +212,8 @@ def __call__(
167212 - If distribution: `n_samples` are drawn to estimate CRPS.
168213 y_true: TensorLike
169214 True target values of shape `(batch_size, *target_shape)`.
170- n_samples: int
171- Number of samples to draw from the predicted distribution if `y_pred` is a
172- distribution. Defaults to 1000.
215+ metric_params: MetricParams
216+ Metric parameters including: n_samples.
173217
174218 Returns
175219 -------
@@ -184,6 +228,9 @@ def __call__(
184228 if not isinstance (y_true , TensorLike ):
185229 raise ValueError (f"y_true must be a tensor, got { type (y_true )} " )
186230
231+ if metric_params is None :
232+ metric_params = MetricParams ()
233+
187234 # Ensure 2D y_true for consistent handling
188235 y_true = y_true .unsqueeze (- 1 ) if y_true .ndim == 1 else y_true
189236
@@ -195,7 +242,7 @@ def __call__(
195242 if isinstance (y_pred , DistributionLike ):
196243 # Distribution case: sample from it
197244 samples = rearrange (
198- y_pred .sample (torch .Size ((n_samples ,))),
245+ y_pred .sample (torch .Size ((metric_params . n_samples ,))),
199246 "s b ... -> b ... s" ,
200247 )
201248 if samples .shape [:- 1 ] != y_true .shape :
@@ -236,6 +283,134 @@ def __call__(
236283 return crps_metric (samples_flat , y_true_flat )
237284
238285
286+ class MSLLMetric (ProbabilisticMetric ):
287+ """Mean Standardized Log Loss (MSLL) metric.
288+
289+ MSLL evaluates the quality of probabilistic predictions by measuring the
290+ log-likelihood of the true values under the predictive distribution,
291+ standardized by the log-likelihood under the trivial model (i.e., predictive
292+ normal distribution parameterized with the data mean and variance).
293+
294+ If no training data is supplied, the mean log loss is computed.
295+
296+ Lower MSLL values indicate better predictive performance.
297+
298+ Note: This metric requires probabilistic predictions. Standardization
299+ assumes that the predictive distribution is Gaussian.
300+
301+ Attributes
302+ ----------
303+ name: str
304+ Display name for the metric.
305+ maximize: bool
306+ Whether higher values are better. False for MSLL (lower is better).
307+ """
308+
309+ name : str = "msll"
310+ maximize : bool = False
311+
312+ def __call__ (
313+ self ,
314+ y_pred : OutputLike ,
315+ y_true : TensorLike ,
316+ metric_params : MetricParams | None = None ,
317+ ) -> TensorLike :
318+ """Calculate MSLL metric.
319+
320+ If no training data is provided in `metric_params.y_train`, the mean log loss
321+ is computed without standardization.
322+
323+ Parameters
324+ ----------
325+ y_pred: OutputLike
326+ Predicted outputs. Must be a distribution.
327+ y_true: TensorLike
328+ True target values.
329+ metric_params: MetricParams
330+ Metric parameters including: y_train and reduction.
331+
332+ Returns
333+ -------
334+ TensorLike
335+ Mean Standardized Log Loss (MSLL) score.
336+
337+ Raises
338+ ------
339+ ValueError
340+ If y_pred is not a distribution.
341+ """
342+ if not isinstance (y_pred , DistributionLike ):
343+ raise ValueError (
344+ f"MSLL metric requires probabilistic predictions, got { type (y_pred )} . "
345+ )
346+
347+ if not isinstance (y_true , TensorLike ):
348+ raise ValueError (f"y_true must be a tensor, got { type (y_true )} " )
349+
350+ if metric_params is None :
351+ metric_params = MetricParams ()
352+
353+ # Ensure 2D y_true for consistent handling
354+ y_true = y_true .unsqueeze (- 1 ) if y_true .ndim == 1 else y_true
355+
356+ # Compute mean negative log likelihood (also by output dimension if have
357+ # Independent distribution to support 'none' reduction)
358+ if isinstance (y_pred , Independent ):
359+ model_nll_output = - y_pred .base_dist .log_prob (y_true ).mean (dim = 0 )
360+ model_nll_total = model_nll_output .mean ()
361+ else :
362+ model_nll_output = None
363+ model_nll_total = - y_pred .log_prob (y_true ).mean ()
364+
365+ # If no training data, return mean log loss
366+ if metric_params .y_train is None :
367+ if metric_params .reduction == "mean" :
368+ return model_nll_total
369+ if metric_params .reduction == "none" :
370+ if model_nll_output is None :
371+ msg = (
372+ "Per-output MLL not available for non-Independent "
373+ "distributions."
374+ )
375+ raise ValueError (msg )
376+ return model_nll_output .reshape (* y_true .shape [1 :])
377+ msg = (
378+ f"Unknown reduction method: { metric_params .reduction } . "
379+ "Expected 'mean' or 'none'."
380+ )
381+ raise ValueError (msg )
382+
383+ # Keep original shape for y_train_mean to match y_true shape
384+ y_train_mean = metric_params .y_train .mean (dim = 0 , keepdim = True )
385+
386+ # following GPyTorch implementation, use global variance rather than per task
387+ # https://github.com/cornellius-gp/gpytorch/blob/c0fb6c64311fdbef2862fd3ba2bd613fbd081e79/gpytorch/metrics/metrics.py#L60
388+ y_train_var = metric_params .y_train .var ()
389+
390+ # Avoid numerical issues
391+ y_train_var = torch .clamp (y_train_var , min = 1e-6 )
392+
393+ # Compute mean negative log likelihood under trivial Gaussian model
394+ trivial_nll_output = 0.5 * (
395+ torch .log (2 * torch .pi * y_train_var )
396+ + torch .square (y_true - y_train_mean ) / (2 * y_train_var )
397+ ).mean (dim = 0 )
398+
399+ # Return mean standardized log loss
400+ if metric_params .reduction == "mean" :
401+ return model_nll_total - trivial_nll_output .mean ()
402+ if metric_params .reduction == "none" :
403+ if model_nll_output is None :
404+ msg = "Per-output MLL not available for non-Independent distributions."
405+ raise ValueError (msg )
406+ return (model_nll_output - trivial_nll_output ).reshape (* y_true .shape [1 :])
407+ msg = (
408+ f"Unknown reduction method: { metric_params .reduction } . "
409+ "Expected 'mean' or 'none'."
410+ )
411+ raise ValueError (msg )
412+
413+
239414R2 = TorchMetrics (
240415 metric = torchmetrics .R2Score ,
241416 name = "r2" ,
@@ -262,12 +437,15 @@ def __call__(
262437
263438CRPS = CRPSMetric ()
264439
440+ MSLL = MSLLMetric ()
441+
265442AVAILABLE_METRICS = {
266443 "r2" : R2 ,
267444 "rmse" : RMSE ,
268445 "mse" : MSE ,
269446 "mae" : MAE ,
270447 "crps" : CRPS ,
448+ "msll" : MSLL ,
271449}
272450
273451
0 commit comments