
Commit bf210f5

Merge pull request #904 from alan-turing-institute/884-crps-metric
Add CRPS metric (#884)
2 parents 3bc0d86 + db288c5 commit bf210f5

6 files changed (+362, -24 lines)

autoemulate/core/compare.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 from autoemulate.core.device import TorchDeviceMixin
 from autoemulate.core.logging_config import get_configured_logger
 from autoemulate.core.metrics import (
+    R2,
     TorchMetrics,
     get_metric_config,
     get_metric_configs,
@@ -683,7 +684,7 @@ def plot( # noqa: PLR0912, PLR0915
 
         # Re-run prediction with just this model to get the predictions
         y_pred, y_variance = model.predict_mean_and_variance(test_x)
-        r2_score = evaluate(y_pred, test_y)
+        r2_score = evaluate(y_pred, test_y, metric=R2)
 
         # Handle ranges
         input_ranges = input_ranges or {}

autoemulate/core/metrics.py

Lines changed: 150 additions & 9 deletions
@@ -7,30 +7,39 @@
 from functools import partial
 
 import torchmetrics
-
-from autoemulate.core.types import OutputLike, TensorLike, TorchMetricsLike
+from einops import rearrange
+from torchmetrics.regression.crps import ContinuousRankedProbabilityScore
+
+from autoemulate.core.types import (
+    DistributionLike,
+    OutputLike,
+    TensorLike,
+    TorchMetricsLike,
+)
 
 
 class Metric:
     """Configuration for a single metric.
 
     Parameters
     ----------
-    name : str
+    name: str
         Display name for the metric.
-    maximize : bool
+    maximize: bool
         Whether higher values are better. Defaults to True.
     """
 
     name: str
     maximize: bool
 
     def __repr__(self) -> str:
-        """Return the string representation of the MetricConfig."""
-        return f"MetricConfig(name={self.name}, maximize={self.maximize})"
+        """Return the string representation of the Metric."""
+        return f"Metric(name={self.name}, maximize={self.maximize})"
 
     @abstractmethod
-    def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
         """Calculate metric."""
 
 
@@ -57,20 +66,149 @@ def __init__(
         self.name = name
         self.maximize = maximize
 
-    def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
         """Calculate metric."""
-        if not isinstance(y_pred, TensorLike):
+        if not isinstance(y_pred, OutputLike):
             raise ValueError(f"Metric not implemented for y_pred ({type(y_pred)})")
         if not isinstance(y_true, TensorLike):
             raise ValueError(f"Metric not implemented for y_true ({type(y_true)})")
 
+        # Handle probabilistic predictions
+        if isinstance(y_pred, DistributionLike):
+            try:
+                y_pred = y_pred.mean
+            except Exception:
+                y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
         metric = self.metric()
         metric.to(y_pred.device)
         # Assume first dim is a batch dim, flatten others for metric calculation
         metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))
         return metric.compute()
 
 
+class ProbabilisticMetric(Metric):
+    """Base class for probabilistic metrics."""
+
+    @abstractmethod
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
+        """Calculate metric."""
+
+
+class CRPSMetric(ProbabilisticMetric):
+    """Continuous Ranked Probability Score (CRPS) metric.
+
+    CRPS is a scoring rule for evaluating probabilistic predictions. It reduces to mean
+    absolute error (MAE) for deterministic predictions and generalizes to distributions
+    by measuring the integral difference between predicted and actual CDFs.
+
+    The metric aggregates over batch and target dimensions by computing the mean
+    CRPS across all scalar outputs, making it comparable across different batch
+    sizes and output dimensions.
+
+    Attributes
+    ----------
+    name: str
+        Display name for the metric.
+    maximize: bool
+        Whether higher values are better. False for CRPS (lower is better).
+    """
+
+    name: str = "crps"
+    maximize: bool = False
+
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
+        """Calculate CRPS metric.
+
+        The metric handles both deterministic predictions (tensors) and probabilistic
+        predictions (tensors of samples or distributions).
+
+        Aggregation across batch and target dimensions is performed by computing the
+        mean CRPS across all scalar outputs. This makes the metric comparable across
+        different batch sizes and target dimensions.
+
+        Parameters
+        ----------
+        y_pred: OutputLike
+            Predicted outputs. Can be a tensor or a distribution.
+            - If tensor with shape `(batch_size, *target_shape)`: treated as
+              deterministic prediction (reduces to MAE).
+            - If tensor with shape `(batch_size, *target_shape, n_samples)`: treated as
+              samples from a probabilistic prediction.
+            - If distribution: `n_samples` are drawn to estimate CRPS.
+        y_true: TensorLike
+            True target values of shape `(batch_size, *target_shape)`.
+        n_samples: int
+            Number of samples to draw from the predicted distribution if `y_pred` is a
+            distribution. Defaults to 1000.
+
+        Returns
+        -------
+        TensorLike
+            Mean CRPS score across all batch elements and target dimensions.
+
+        Raises
+        ------
+        ValueError
+            If input types or shapes are incompatible.
+        """
+        if not isinstance(y_true, TensorLike):
+            raise ValueError(f"y_true must be a tensor, got {type(y_true)}")
+
+        # Ensure 2D y_true for consistent handling
+        y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
+
+        # Initialize CRPS metric (computes mean by default)
+        crps_metric = ContinuousRankedProbabilityScore()
+        crps_metric.to(y_true.device)
+
+        # Handle different prediction types
+        if isinstance(y_pred, DistributionLike):
+            # Distribution case: sample from it
+            samples = rearrange(y_pred.sample((n_samples,)), "s b ... -> b ... s")
+            if samples.shape[:-1] != y_true.shape:
+                raise ValueError(
+                    f"Sampled predictions shape {samples.shape[:-1]} (excluding sample "
+                    f"dimension) does not match y_true shape {y_true.shape}"
+                )
+        elif isinstance(y_pred, TensorLike):
+            # Tensor case: check dimensions
+            if y_pred.dim() == y_true.dim():
+                # Deterministic: same shape as y_true
+                # CRPS requires at least 2 ensemble members, so duplicate the prediction
+                samples = y_pred.unsqueeze(-1).repeat_interleave(2, dim=-1)
+            elif y_pred.dim() == y_true.dim() + 1:
+                # Probabilistic: already has sample dimension at end
+                samples = y_pred
+                if samples.shape[:-1] != y_true.shape:
+                    raise ValueError(
+                        f"y_pred shape {samples.shape[:-1]} (excluding last dimension) "
+                        f"does not match y_true shape {y_true.shape}"
+                    )
+            else:
+                raise ValueError(
+                    f"y_pred dimensions ({y_pred.dim()}) incompatible with y_true "
+                    f"dimensions ({y_true.dim()}). Expected same dimensions or "
+                    f"y_true.dim() + 1"
+                )
+        else:
+            raise ValueError(
+                f"y_pred must be a tensor or distribution, got {type(y_pred)}"
+            )
+
+        # Flatten batch and target dimensions
+        samples_flat = samples.flatten(end_dim=-2)  # (batch * targets, n_samples)
+        y_true_flat = y_true.flatten()  # (batch * targets,)
+
+        # ContinuousRankedProbabilityScore computes mean by default
+        return crps_metric(samples_flat, y_true_flat)
+
+
 R2 = TorchMetrics(
     metric=torchmetrics.R2Score,
     name="r2",
@@ -95,11 +233,14 @@ def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
     maximize=False,
 )
 
+CRPS = CRPSMetric()
+
 AVAILABLE_METRICS = {
     "r2": R2,
     "rmse": RMSE,
     "mse": MSE,
     "mae": MAE,
+    "crps": CRPS,
 }
 
 
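As background (standard definitions, not part of this diff): for a predictive CDF F and an observation y, the score that CRPSMetric estimates is

\mathrm{CRPS}(F, y) = \int_{-\infty}^{\infty} \bigl( F(x) - \mathbf{1}\{x \ge y\} \bigr)^2 \, dx

and, with an ensemble of m samples x_1, \dots, x_m drawn from the prediction, it is commonly estimated by the energy form

\mathrm{CRPS}(x_{1:m}, y) \approx \frac{1}{m} \sum_{i=1}^{m} \lvert x_i - y \rvert - \frac{1}{2 m^2} \sum_{i=1}^{m} \sum_{j=1}^{m} \lvert x_i - x_j \rvert .

For a deterministic prediction \hat{x} (the duplicated-sample case handled above), the double sum vanishes and the score reduces to \lvert \hat{x} - y \rvert, i.e. MAE after averaging over outputs, matching the docstring. The exact ensemble estimator is whatever torchmetrics' ContinuousRankedProbabilityScore implements.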
autoemulate/core/model_selection.py

Lines changed: 18 additions & 11 deletions
@@ -14,6 +14,7 @@
 from autoemulate.core.types import (
     DeviceLike,
     ModelParams,
+    OutputLike,
     TensorLike,
     TransformedEmulatorParams,
 )
@@ -25,27 +26,31 @@
 
 
 def evaluate(
-    y_pred: TensorLike,
+    y_pred: OutputLike,
     y_true: TensorLike,
     metric: Metric = R2,
+    n_samples: int = 1000,
 ) -> float:
     """
     Evaluate Emulator prediction performance using a `torchmetrics.Metric`.
 
     Parameters
     ----------
+    y_pred: OutputLike
+        Predicted target values, as returned by an Emulator.
     y_true: TensorLike
         Ground truth target values.
-    y_pred: TensorLike
-        Predicted target values, as returned by an Emulator.
     metric: Metric
         Metric to use for evaluation. Defaults to R2.
+    n_samples: int
+        Number of samples to generate to predict mean when y_pred does not have a mean
+        directly available. Defaults to 1000.
 
     Returns
    -------
     float
     """
-    return metric(y_pred, y_true).item()
+    return metric(y_pred, y_true, n_samples=n_samples).item()
 
 
 def cross_validate(
@@ -139,7 +144,7 @@ def cross_validate(
         transformed_emulator.fit(x, y)
 
         # compute and save results
-        y_pred = transformed_emulator.predict_mean(x_val)
+        y_pred = transformed_emulator.predict(x_val)
         for metric in metrics:
             score = evaluate(y_pred, y_val, metric)
             cv_results[metric.name].append(score)
@@ -151,7 +156,7 @@
     x: TensorLike,
     y: TensorLike,
     n_bootstraps: int | None = 100,
-    n_samples: int = 100,
+    n_samples: int = 1000,
     device: str | torch.device = "cpu",
     metrics: list[TorchMetrics] | None = None,
 ) -> dict[str, tuple[float, float]]:
@@ -172,7 +177,7 @@
         Defaults to 100.
     n_samples: int
         Number of samples to generate to predict mean when emulator does not have a
-        mean directly available. Defaults to 100.
+        mean directly available. Defaults to 1000.
     device: str | torch.device
         The device to use for computations. Default is "cpu".
     metrics: list[MetricConfig] | None
@@ -192,10 +197,10 @@
 
     # If no bootstraps are specified, fall back to a single evaluation on given data
     if n_bootstraps is None:
-        y_pred = model.predict_mean(x, n_samples=n_samples)
+        y_pred = model.predict(x)
         results = {}
         for metric in metrics:
-            score = evaluate(y_pred, y, metric)
+            score = evaluate(y_pred, y, metric=metric, n_samples=n_samples)
             results[metric.name] = (score, float("nan"))
         return results
 
@@ -213,11 +218,13 @@
         y_bootstrap = y[idxs]
 
         # Make predictions
-        y_pred = model.predict_mean(x_bootstrap, n_samples=n_samples)
+        y_pred = model.predict(x_bootstrap)
 
         # Compute metrics for this bootstrap sample
         for metric in metrics:
-            metric_scores[metric.name][i] = evaluate(y_pred, y_bootstrap, metric)
+            metric_scores[metric.name][i] = evaluate(
+                y_pred, y_bootstrap, metric=metric, n_samples=n_samples
+            )
 
     # Return mean and std for each metric
     return {
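A minimal usage sketch of the updated evaluate signature together with the new CRPS metric. The Normal predictive distribution and tensor shapes below are illustrative assumptions (and it is assumed that DistributionLike covers torch.distributions objects); this snippet is not part of the commit.

import torch

from autoemulate.core.metrics import CRPS, R2
from autoemulate.core.model_selection import evaluate

# Stand-in for a probabilistic emulator prediction: an independent Normal per output.
y_true = torch.randn(32, 2)
y_pred = torch.distributions.Normal(loc=y_true + 0.1, scale=torch.full((32, 2), 0.5))

# CRPS is estimated from n_samples draws of the predictive distribution
# (lower is better: CRPS.maximize is False).
crps_score = evaluate(y_pred, y_true, metric=CRPS, n_samples=1000)

# Deterministic metrics such as R2 fall back to the distribution mean.
r2_score = evaluate(y_pred, y_true, metric=R2)

print(f"CRPS: {crps_score:.4f}, R2: {r2_score:.4f}")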

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ dependencies = [
     "torchrbf>=0.0.1",
     "arviz>=0.21.0",
     "getdist>=1.7.2",
+    "einops>=0.8.1",
 ]
 
 [project.urls]
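The einops dependency added here supports the rearrange call in CRPSMetric, which moves the drawn sample dimension to the end before scoring. A small standalone sketch of that rearrangement (shapes are made up for the example):

import torch
from einops import rearrange

# Draws from a predictive distribution arrive as (n_samples, batch, *targets);
# the CRPS estimator expects the ensemble dimension last: (batch, *targets, n_samples).
samples = torch.randn(1000, 32, 2)
ensemble_last = rearrange(samples, "s b ... -> b ... s")
print(ensemble_last.shape)  # torch.Size([32, 2, 1000])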
