77from functools import partial
88
99import torchmetrics
10-
11- from autoemulate .core .types import OutputLike , TensorLike , TorchMetricsLike
10+ from einops import rearrange
11+ from torchmetrics .regression .crps import ContinuousRankedProbabilityScore
12+
13+ from autoemulate .core .types import (
14+ DistributionLike ,
15+ OutputLike ,
16+ TensorLike ,
17+ TorchMetricsLike ,
18+ )
1219
1320
1421class Metric :
1522 """Configuration for a single metric.
1623
1724 Parameters
1825 ----------
19- name : str
26+ name: str
2027 Display name for the metric.
21- maximize : bool
28+ maximize: bool
2229 Whether higher values are better. Defaults to True.
2330 """
2431
2532 name : str
2633 maximize : bool
2734
2835 def __repr__ (self ) -> str :
29- """Return the string representation of the MetricConfig ."""
30- return f"MetricConfig (name={ self .name } , maximize={ self .maximize } )"
36+ """Return the string representation of the Metric ."""
37+ return f"Metric (name={ self .name } , maximize={ self .maximize } )"
3138
3239 @abstractmethod
33- def __call__ (self , y_pred : OutputLike , y_true : TensorLike ) -> TensorLike :
40+ def __call__ (
41+ self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
42+ ) -> TensorLike :
3443 """Calculate metric."""
3544
3645
@@ -57,20 +66,149 @@ def __init__(
5766 self .name = name
5867 self .maximize = maximize
5968
60- def __call__ (self , y_pred : OutputLike , y_true : TensorLike ) -> TensorLike :
69+ def __call__ (
70+ self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
71+ ) -> TensorLike :
6172 """Calculate metric."""
62- if not isinstance (y_pred , TensorLike ):
73+ if not isinstance (y_pred , OutputLike ):
6374 raise ValueError (f"Metric not implemented for y_pred ({ type (y_pred )} )" )
6475 if not isinstance (y_true , TensorLike ):
6576 raise ValueError (f"Metric not implemented for y_true ({ type (y_true )} )" )
6677
78+ # Handle probabilistic predictions
79+ if isinstance (y_pred , DistributionLike ):
80+ try :
81+ y_pred = y_pred .mean
82+ except Exception :
83+ y_pred = y_pred .rsample ((n_samples ,)).mean (dim = 0 )
6784 metric = self .metric ()
6885 metric .to (y_pred .device )
6986 # Assume first dim is a batch dim, flatten others for metric calculation
7087 metric .update (y_pred .flatten (start_dim = 1 ), y_true .flatten (start_dim = 1 ))
7188 return metric .compute ()
7289
7390
91+ class ProbabilisticMetric (Metric ):
92+ """Base class for probabilistic metrics."""
93+
94+ @abstractmethod
95+ def __call__ (
96+ self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
97+ ) -> TensorLike :
98+ """Calculate metric."""
99+
100+
101+ class CRPSMetric (ProbabilisticMetric ):
102+ """Continuous Ranked Probability Score (CRPS) metric.
103+
104+ CRPS is a scoring rule for evaluating probabilistic predictions. It reduces to mean
105+ absolute error (MAE) for deterministic predictions and generalizes to distributions
106+ by measuring the integral difference between predicted and actual CDFs.
107+
108+ The metric aggregates over batch and target dimensions by computing the mean
109+ CRPS across all scalar outputs, making it comparable across different batch
110+ sizes and output dimensions.
111+
112+ Attributes
113+ ----------
114+ name: str
115+ Display name for the metric.
116+ maximize: bool
117+ Whether higher values are better. False for CRPS (lower is better).
118+ """
119+
120+ name : str = "crps"
121+ maximize : bool = False
122+
123+ def __call__ (
124+ self , y_pred : OutputLike , y_true : TensorLike , n_samples : int = 1000
125+ ) -> TensorLike :
126+ """Calculate CRPS metric.
127+
128+ The metric handles both deterministic predictions (tensors) and probabilistic
129+ predictions (tensors of samples or distributions).
130+
131+ Aggregation across batch and target dimensions is performed by computing the
132+ mean CRPS across all scalar outputs. This makes the metric comparable across
133+ different batch sizes and target dimensions.
134+
135+ Parameters
136+ ----------
137+ y_pred: OutputLike
138+ Predicted outputs. Can be a tensor or a distribution.
139+ - If tensor with shape `(batch_size, *target_shape)`: treated as
140+ deterministic prediction (reduces to MAE).
141+ - If tensor with shape `(batch_size, *target_shape, n_samples)`: treated as
142+ samples from a probabilistic prediction.
143+ - If distribution: `n_samples` are drawn to estimate CRPS.
144+ y_true: TensorLike
145+ True target values of shape `(batch_size, *target_shape)`.
146+ n_samples: int
147+ Number of samples to draw from the predicted distribution if `y_pred` is a
148+ distribution. Defaults to 1000.
149+
150+ Returns
151+ -------
152+ TensorLike
153+ Mean CRPS score across all batch elements and target dimensions.
154+
155+ Raises
156+ ------
157+ ValueError
158+ If input types or shapes are incompatible.
159+ """
160+ if not isinstance (y_true , TensorLike ):
161+ raise ValueError (f"y_true must be a tensor, got { type (y_true )} " )
162+
163+ # Ensure 2D y_true for consistent handling
164+ y_true = y_true .unsqueeze (- 1 ) if y_true .ndim == 1 else y_true
165+
166+ # Initialize CRPS metric (computes mean by default)
167+ crps_metric = ContinuousRankedProbabilityScore ()
168+ crps_metric .to (y_true .device )
169+
170+ # Handle different prediction types
171+ if isinstance (y_pred , DistributionLike ):
172+ # Distribution case: sample from it
173+ samples = rearrange (y_pred .sample ((n_samples ,)), "s b ... -> b ... s" )
174+ if samples .shape [:- 1 ] != y_true .shape :
175+ raise ValueError (
176+ f"Sampled predictions shape { samples .shape [:- 1 ]} (excluding sample "
177+ f"dimension) does not match y_true shape { y_true .shape } "
178+ )
179+ elif isinstance (y_pred , TensorLike ):
180+ # Tensor case: check dimensions
181+ if y_pred .dim () == y_true .dim ():
182+ # Deterministic: same shape as y_true
183+ # CRPS requires at least 2 ensemble members, so duplicate the prediction
184+ samples = y_pred .unsqueeze (- 1 ).repeat_interleave (2 , dim = - 1 )
185+ elif y_pred .dim () == y_true .dim () + 1 :
186+ # Probabilistic: already has sample dimension at end
187+ samples = y_pred
188+ if samples .shape [:- 1 ] != y_true .shape :
189+ raise ValueError (
190+ f"y_pred shape { samples .shape [:- 1 ]} (excluding last dimension) "
191+ f"does not match y_true shape { y_true .shape } "
192+ )
193+ else :
194+ raise ValueError (
195+ f"y_pred dimensions ({ y_pred .dim ()} ) incompatible with y_true "
196+ f"dimensions ({ y_true .dim ()} ). Expected same dimensions or "
197+ f"y_true.dim() + 1"
198+ )
199+ else :
200+ raise ValueError (
201+ f"y_pred must be a tensor or distribution, got { type (y_pred )} "
202+ )
203+
204+ # Flatten batch and target dimensions
205+ samples_flat = samples .flatten (end_dim = - 2 ) # (batch * targets, n_samples)
206+ y_true_flat = y_true .flatten () # (batch * targets,)
207+
208+ # ContinuousRankedProbabilityScore computes mean by default
209+ return crps_metric (samples_flat , y_true_flat )
210+
211+
74212R2 = TorchMetrics (
75213 metric = torchmetrics .R2Score ,
76214 name = "r2" ,
@@ -95,11 +233,14 @@ def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
95233 maximize = False ,
96234)
97235
236+ CRPS = CRPSMetric ()
237+
98238AVAILABLE_METRICS = {
99239 "r2" : R2 ,
100240 "rmse" : RMSE ,
101241 "mse" : MSE ,
102242 "mae" : MAE ,
243+ "crps" : CRPS ,
103244}
104245
105246
0 commit comments